Skip to main content

temporalio_client/
worker.rs

1//! Contains types and logic for interactions between clients and Core/SDK workers
2
3use anyhow::bail;
4use parking_lot::RwLock;
5use rand::seq::SliceRandom;
6use std::{
7    collections::{
8        HashMap,
9        hash_map::Entry::{Occupied, Vacant},
10    },
11    sync::Arc,
12};
13use temporalio_common::{
14    protos::{
15        TaskToken,
16        temporal::api::{
17            worker::v1::WorkerHeartbeat, workflowservice::v1::PollWorkflowTaskQueueResponse,
18        },
19    },
20    worker::{WorkerDeploymentOptions, WorkerTaskTypes},
21};
22use uuid::Uuid;
23
24/// This trait represents a slot reserved for processing a WFT by a worker.
25#[cfg_attr(test, mockall::automock)]
26pub trait Slot {
27    /// Consumes this slot by dispatching a WFT to its worker. This can only be called once.
28    fn schedule_wft(
29        self: Box<Self>,
30        task: PollWorkflowTaskQueueResponse,
31    ) -> Result<(), anyhow::Error>;
32}
33
34/// Result of reserving a workflow task slot, including deployment options if applicable.
35pub(crate) struct SlotReservation {
36    /// The reserved slot for processing the workflow task
37    pub slot: Box<dyn Slot + Send>,
38    /// Worker deployment options, if the worker is using deployment-based versioning
39    pub deployment_options: Option<WorkerDeploymentOptions>,
40}
41
/// Bucket key under which slot providers are grouped: a namespace plus a task queue.
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
struct SlotKey {
    namespace: String,
    task_queue: String,
}

impl SlotKey {
    /// Builds a key from an owned namespace and task queue name.
    fn new(namespace: String, task_queue: String) -> SlotKey {
        Self {
            task_queue,
            namespace,
        }
    }
}
56
57/// Information about a registered worker in the slot provider registry
58#[derive(Debug, Clone)]
59struct RegisteredWorkerInfo {
60    /// Unique identifier for this worker instance
61    worker_id: Uuid,
62    /// Optional deployment build ID for versioning
63    build_id: Option<String>,
64    /// Task types this worker can handle
65    task_types: WorkerTaskTypes,
66}
67
68impl RegisteredWorkerInfo {
69    fn new(worker_id: Uuid, build_id: Option<String>, task_types: WorkerTaskTypes) -> Self {
70        Self {
71            worker_id,
72            build_id,
73            task_types,
74        }
75    }
76}
77
78/// This is an inner class for [ClientWorkerSet] needed to hide the mutex.
79struct ClientWorkerSetImpl {
80    /// Maps slot keys to registered worker information
81    slot_providers: HashMap<SlotKey, Vec<RegisteredWorkerInfo>>,
82    /// Maps worker_instance_key to registered workers
83    all_workers: HashMap<Uuid, Arc<dyn ClientWorker + Send + Sync>>,
84    /// Maps namespace to shared worker for worker heartbeating
85    shared_worker: HashMap<String, Box<dyn SharedNamespaceWorkerTrait + Send + Sync>>,
86}
87
88impl ClientWorkerSetImpl {
89    /// Factory method.
90    fn new() -> Self {
91        Self {
92            slot_providers: Default::default(),
93            all_workers: Default::default(),
94            shared_worker: Default::default(),
95        }
96    }
97
98    fn try_reserve_wft_slot(
99        &self,
100        namespace: String,
101        task_queue: String,
102    ) -> Option<SlotReservation> {
103        let key = SlotKey::new(namespace, task_queue);
104        if let Some(worker_list) = self.slot_providers.get(&key) {
105            let workflow_workers: Vec<&RegisteredWorkerInfo> = worker_list
106                .iter()
107                .filter(|info| info.task_types.enable_workflows)
108                .collect();
109
110            for worker_id in Self::worker_ids_in_selection_order(&workflow_workers) {
111                if let Some(worker) = self.all_workers.get(&worker_id)
112                    && let Some(slot) = worker.try_reserve_wft_slot()
113                {
114                    let deployment_options = worker.deployment_options();
115                    return Some(SlotReservation {
116                        slot,
117                        deployment_options,
118                    });
119                }
120            }
121        }
122        None
123    }
124
125    fn worker_ids_in_selection_order(worker_list: &[&RegisteredWorkerInfo]) -> Vec<Uuid> {
126        // For tests we return workers in the order they're registered, so we can test
127        // the retry mechanism deterministically
128        if cfg!(test) {
129            worker_list.iter().map(|info| info.worker_id).collect()
130        } else {
131            let mut rng = rand::rng();
132            let mut shuffled: Vec<_> = worker_list.to_vec();
133            shuffled.shuffle(&mut rng);
134            shuffled.iter().map(|info| info.worker_id).collect()
135        }
136    }
137
138    fn register(
139        &mut self,
140        worker: Arc<dyn ClientWorker + Send + Sync>,
141        skip_client_worker_set_check: bool,
142    ) -> Result<(), anyhow::Error> {
143        let slot_key = SlotKey::new(
144            worker.namespace().to_string(),
145            worker.task_queue().to_string(),
146        );
147        let build_id = worker
148            .deployment_options()
149            .map(|opts| opts.version.build_id);
150        let task_types = worker.worker_task_types();
151
152        if !task_types.enable_workflows
153            && !task_types.enable_local_activities
154            && !task_types.enable_remote_activities
155            && !task_types.enable_nexus
156        {
157            bail!(
158                "Worker must have at least one capability enabled (workflows, activities, or nexus)"
159            );
160        }
161
162        if !task_types.enable_workflows && task_types.enable_local_activities {
163            bail!("Local activities cannot be enabled without workflows")
164        }
165
166        if !skip_client_worker_set_check
167            && let Some(existing_workers) = self.slot_providers.get(&slot_key)
168        {
169            for existing_worker_info in existing_workers {
170                if existing_worker_info.build_id.as_ref() == build_id.as_ref()
171                    && task_types.overlaps_with(&existing_worker_info.task_types)
172                {
173                    bail!(
174                        "Registration of multiple workers with overlapping worker task types \
175                        on the same namespace, task queue, and deployment build ID not allowed: \
176                        {slot_key:?}, worker_instance_key: {:?} \
177                        build_id: {build_id:?}, \
178                        new task types: {task_types:?}, \
179                        existing task types: {:?}.",
180                        existing_worker_info.task_types,
181                        worker.worker_instance_key()
182                    );
183                }
184            }
185        }
186
187        if worker.heartbeat_enabled()
188            && let Some(heartbeat_callback) = worker.heartbeat_callback()
189        {
190            let worker_instance_key = worker.worker_instance_key();
191            let namespace = worker.namespace().to_string();
192
193            let shared_worker = match self.shared_worker.entry(namespace.clone()) {
194                Occupied(o) => o.into_mut(),
195                Vacant(v) => {
196                    let shared_worker = worker.new_shared_namespace_worker()?;
197                    v.insert(shared_worker)
198                }
199            };
200            shared_worker.register_callback(
201                worker_instance_key,
202                WorkerCallbacks {
203                    heartbeat: heartbeat_callback,
204                    cancel_activity: worker.cancel_activity_callback(),
205                },
206            );
207        }
208
209        let worker_info =
210            RegisteredWorkerInfo::new(worker.worker_instance_key(), build_id, task_types);
211
212        match self.slot_providers.entry(slot_key.clone()) {
213            Occupied(o) => o.into_mut().push(worker_info),
214            Vacant(v) => {
215                v.insert(vec![worker_info]);
216            }
217        };
218
219        self.all_workers
220            .insert(worker.worker_instance_key(), worker);
221
222        Ok(())
223    }
224
225    /// Slot provider should be unregistered at the beginning of worker shutdown, in order to disable
226    /// eager workflow start.
227    fn unregister_slot_provider(&mut self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
228        let worker = self.all_workers.get(&worker_instance_key).ok_or_else(|| {
229            anyhow::anyhow!("Worker not in all_workers during slot provider unregister")
230        })?;
231
232        let slot_key = SlotKey::new(
233            worker.namespace().to_string(),
234            worker.task_queue().to_string(),
235        );
236        if let Some(slot_vec) = self.slot_providers.get_mut(&slot_key) {
237            slot_vec.retain(|info| info.worker_id != worker_instance_key);
238            if slot_vec.is_empty() {
239                self.slot_providers.remove(&slot_key);
240            }
241        }
242        Ok(())
243    }
244
245    fn finalize_unregister(
246        &mut self,
247        worker_instance_key: Uuid,
248    ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
249        if let Some(worker) = self.all_workers.get(&worker_instance_key)
250            && let Some(slot_vec) = self.slot_providers.get(&SlotKey::new(
251                worker.namespace().to_string(),
252                worker.task_queue().to_string(),
253            ))
254            && slot_vec
255                .iter()
256                .any(|info| info.worker_id == worker_instance_key)
257        {
258            return Err(anyhow::anyhow!(
259                "Worker still in slot_providers during finalize"
260            ));
261        }
262
263        let worker = self
264            .all_workers
265            .remove(&worker_instance_key)
266            .ok_or_else(|| anyhow::anyhow!("Worker not found in all_workers"))?;
267
268        if let Some(w) = self.shared_worker.get_mut(worker.namespace()) {
269            let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key());
270            if callback.is_some() && is_empty {
271                self.shared_worker.remove(worker.namespace());
272            }
273        }
274
275        Ok(worker)
276    }
277
278    #[cfg(test)]
279    fn num_providers(&self) -> usize {
280        self.slot_providers.values().map(|v| v.len()).sum()
281    }
282
283    #[cfg(test)]
284    fn num_heartbeat_workers(&self) -> usize {
285        self.shared_worker.values().map(|v| v.num_workers()).sum()
286    }
287}
288
289/// This trait represents a shared namespace worker that sends worker heartbeats and
290/// receives worker commands.
291pub trait SharedNamespaceWorkerTrait {
292    /// Namespace that the shared namespace worker is connected to.
293    fn namespace(&self) -> String;
294
295    /// Registers worker callbacks.
296    fn register_callback(&self, worker_instance_key: Uuid, callbacks: WorkerCallbacks);
297
298    /// Unregisters worker callbacks. Returns the callbacks removed, as well as a bool that
299    /// indicates if there are no remaining callbacks in the SharedNamespaceWorker, indicating
300    /// the shared worker itself can be shut down.
301    fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<WorkerCallbacks>, bool);
302
303    /// Returns the number of workers registered to this shared worker.
304    fn num_workers(&self) -> usize;
305}
306
307/// Enables local workers to make themselves visible to a shared client instance.
308///
309/// For slot managing, there can only be one worker registered per
310/// namespace+queue_name+connection, others will return an error.
311/// It also provides a convenient method to find compatible slots within the collection.
312pub struct ClientWorkerSet {
313    worker_grouping_key: Uuid,
314    worker_manager: RwLock<ClientWorkerSetImpl>,
315}
316
317impl Default for ClientWorkerSet {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
323impl ClientWorkerSet {
324    /// Factory method.
325    pub fn new() -> Self {
326        Self {
327            worker_grouping_key: Uuid::new_v4(),
328            worker_manager: RwLock::new(ClientWorkerSetImpl::new()),
329        }
330    }
331
332    /// Try to reserve a compatible processing slot in any of the registered workers.
333    /// Returns the slot and the worker's deployment options (if using deployment-based versioning).
334    pub(crate) fn try_reserve_wft_slot(
335        &self,
336        namespace: String,
337        task_queue: String,
338    ) -> Option<SlotReservation> {
339        self.worker_manager
340            .read()
341            .try_reserve_wft_slot(namespace, task_queue)
342    }
343
344    /// Register a local worker that can provide WFT processing slots and potentially worker heartbeating.
345    pub fn register_worker(
346        &self,
347        worker: Arc<dyn ClientWorker + Send + Sync>,
348        skip_client_worker_set_check: bool,
349    ) -> Result<(), anyhow::Error> {
350        self.worker_manager
351            .write()
352            .register(worker, skip_client_worker_set_check)
353    }
354
355    /// Disables Eager Workflow Start for this worker. This must be called before
356    /// `finalize_unregister`, otherwise `finalize_unregister` will return an err.
357    pub fn unregister_slot_provider(&self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
358        self.worker_manager
359            .write()
360            .unregister_slot_provider(worker_instance_key)
361    }
362
363    /// Finalizes unregistering of worker from client. This must be called at the end of worker
364    /// shutdown in order to finalize shutdown for worker heartbeat properly. Must call after
365    /// `unregister_slot_provider`, otherwise an err will be returned.
366    pub fn finalize_unregister(
367        &self,
368        worker_instance_key: Uuid,
369    ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
370        self.worker_manager
371            .write()
372            .finalize_unregister(worker_instance_key)
373    }
374
375    /// Returns the worker grouping key, which is unique for each worker.
376    pub fn worker_grouping_key(&self) -> Uuid {
377        self.worker_grouping_key
378    }
379
380    #[cfg(test)]
381    /// Returns (num_providers, num_buckets), where a bucket key is namespace+task_queue.
382    /// There is only one provider per bucket so `num_providers` should be equal to `num_buckets`.
383    pub fn num_providers(&self) -> usize {
384        self.worker_manager.read().num_providers()
385    }
386
387    #[cfg(test)]
388    /// Returns the total number of heartbeat workers registered across all namespaces.
389    pub fn num_heartbeat_workers(&self) -> usize {
390        self.worker_manager.read().num_heartbeat_workers()
391    }
392}
393
394impl std::fmt::Debug for ClientWorkerSet {
395    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        f.debug_struct("ClientWorkerSet")
397            .field("worker_grouping_key", &self.worker_grouping_key)
398            .finish()
399    }
400}
401
402/// Contains a worker heartbeat callback, wrapped for mocking
403pub type HeartbeatCallback = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
404
405/// Callback to cancel an activity by task token. Returns true if the activity was found.
406pub type CancelActivityCallback = Arc<dyn Fn(TaskToken) -> bool + Send + Sync>;
407
408/// Bundles all per-worker callbacks registered with the SharedNamespaceWorker.
409pub struct WorkerCallbacks {
410    /// Callback to collect heartbeat data from the worker.
411    pub heartbeat: HeartbeatCallback,
412    /// Callback to cancel an activity by task token.
413    pub cancel_activity: Option<CancelActivityCallback>,
414}
415
416/// Represents a complete worker that can handle both slot management
417/// and worker heartbeat functionality.
418#[cfg_attr(test, mockall::automock)]
419pub trait ClientWorker: Send + Sync {
420    /// The namespace this worker operates in
421    fn namespace(&self) -> &str;
422
423    /// The task queue this worker listens to
424    fn task_queue(&self) -> &str;
425
426    /// Try to reserve a slot for workflow task processing.
427    ///
428    /// This method should return `Some(slot)` if a workflow task slot is available,
429    /// or `None` if all slots are currently in use. The returned slot will be used
430    /// to process exactly one workflow task.
431    fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;
432
433    /// Get the worker deployment options for this worker, if using deployment-based versioning.
434    fn deployment_options(&self) -> Option<WorkerDeploymentOptions>;
435
436    /// Unique identifier for this worker instance.
437    /// This must be stable across the worker's lifetime and unique per instance.
438    fn worker_instance_key(&self) -> Uuid;
439
440    /// Indicates if worker heartbeating is enabled for this client worker.
441    fn heartbeat_enabled(&self) -> bool;
442
443    /// Returns the heartbeat callback that can be used to get WorkerHeartbeat data.
444    fn heartbeat_callback(&self) -> Option<HeartbeatCallback>;
445
446    /// Returns a callback that can cancel an activity by task token.
447    fn cancel_activity_callback(&self) -> Option<CancelActivityCallback>;
448
449    /// Creates a new worker that implements the [SharedNamespaceWorkerTrait]
450    fn new_shared_namespace_worker(
451        &self,
452    ) -> Result<Box<dyn SharedNamespaceWorkerTrait + Send + Sync>, anyhow::Error>;
453
454    /// Returns the task types this worker can handle
455    fn worker_task_types(&self) -> WorkerTaskTypes;
456}
457
458#[cfg(test)]
459mod tests {
460    use super::*;
461
462    fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
463        let mut mock_slot = MockSlot::new();
464        if with_error {
465            mock_slot
466                .expect_schedule_wft()
467                .returning(|_| Err(anyhow::anyhow!("Changed my mind")));
468        } else {
469            mock_slot.expect_schedule_wft().returning(|_| Ok(()));
470        }
471        Box::new(mock_slot)
472    }
473
474    fn new_mock_provider(
475        namespace: String,
476        task_queue: String,
477        with_error: bool,
478        no_slots: bool,
479        heartbeat_enabled: bool,
480    ) -> MockClientWorker {
481        let mut mock_provider = MockClientWorker::new();
482        mock_provider
483            .expect_try_reserve_wft_slot()
484            .returning(move || {
485                if no_slots {
486                    None
487                } else {
488                    Some(new_mock_slot(with_error))
489                }
490            });
491        mock_provider.expect_namespace().return_const(namespace);
492        mock_provider.expect_task_queue().return_const(task_queue);
493        mock_provider.expect_deployment_options().return_const(None);
494        mock_provider
495            .expect_heartbeat_enabled()
496            .return_const(heartbeat_enabled);
497        mock_provider
498            .expect_worker_instance_key()
499            .return_const(Uuid::new_v4());
500        mock_provider
501            .expect_worker_task_types()
502            .return_const(WorkerTaskTypes {
503                enable_workflows: true,
504                enable_local_activities: true,
505                enable_remote_activities: true,
506                enable_nexus: true,
507            });
508        mock_provider
509    }
510
511    #[test]
512    fn reserve_wft_slot_retries_another_worker_when_first_has_no_slot() {
513        let mut manager = ClientWorkerSetImpl::new();
514        let namespace = "retry_namespace".to_string();
515        let task_queue = "retry_queue".to_string();
516
517        let failing_worker_id = Uuid::new_v4();
518        let mut failing_worker = MockClientWorker::new();
519        failing_worker
520            .expect_try_reserve_wft_slot()
521            .times(1)
522            .returning(|| None);
523        failing_worker
524            .expect_namespace()
525            .return_const(namespace.clone());
526        failing_worker
527            .expect_task_queue()
528            .return_const(task_queue.clone());
529        failing_worker
530            .expect_deployment_options()
531            .return_const(WorkerDeploymentOptions {
532                version: temporalio_common::worker::WorkerDeploymentVersion {
533                    deployment_name: "test-deployment".to_string(),
534                    build_id: "build-fail".to_string(),
535                },
536                use_worker_versioning: true,
537                default_versioning_behavior: None,
538            });
539        failing_worker
540            .expect_worker_instance_key()
541            .return_const(failing_worker_id);
542        failing_worker
543            .expect_heartbeat_enabled()
544            .return_const(false);
545        failing_worker
546            .expect_worker_task_types()
547            .return_const(WorkerTaskTypes {
548                enable_workflows: true,
549                enable_local_activities: true,
550                enable_remote_activities: true,
551                enable_nexus: true,
552            });
553
554        let succeeding_worker_id = Uuid::new_v4();
555        let mut succeeding_worker = MockClientWorker::new();
556        succeeding_worker
557            .expect_try_reserve_wft_slot()
558            .times(1)
559            .returning(|| Some(new_mock_slot(false)));
560        succeeding_worker
561            .expect_namespace()
562            .return_const(namespace.clone());
563        succeeding_worker
564            .expect_task_queue()
565            .return_const(task_queue.clone());
566        let success_deployment_options = WorkerDeploymentOptions {
567            version: temporalio_common::worker::WorkerDeploymentVersion {
568                deployment_name: "test-deployment".to_string(),
569                build_id: "build-success".to_string(),
570            },
571            use_worker_versioning: true,
572            default_versioning_behavior: None,
573        };
574        succeeding_worker
575            .expect_deployment_options()
576            .return_const(success_deployment_options.clone());
577        succeeding_worker
578            .expect_worker_instance_key()
579            .return_const(succeeding_worker_id);
580        succeeding_worker
581            .expect_heartbeat_enabled()
582            .return_const(false);
583        succeeding_worker
584            .expect_worker_task_types()
585            .return_const(WorkerTaskTypes {
586                enable_workflows: true,
587                enable_local_activities: true,
588                enable_remote_activities: true,
589                enable_nexus: true,
590            });
591
592        manager
593            .register(Arc::new(failing_worker), false)
594            .expect("failing worker registration succeeds");
595        manager
596            .register(Arc::new(succeeding_worker), false)
597            .expect("succeeding worker registration succeeds");
598
599        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
600
601        let reservation_deployment_options = reservation
602            .expect("succeeding worker was used after failing worker failed")
603            .deployment_options
604            .unwrap();
605        assert_eq!(
606            reservation_deployment_options, success_deployment_options,
607            "deployment options bubble through from succeeding worker"
608        );
609    }
610
611    #[test]
612    fn reserve_wft_slot_retries_respects_slot_boundary() {
613        let mut manager = ClientWorkerSetImpl::new();
614        let namespace = "retry_namespace".to_string();
615        let task_queue = "retry_queue".to_string();
616
617        let failing_worker_id = Uuid::new_v4();
618        let mut failing_worker = MockClientWorker::new();
619        failing_worker
620            .expect_try_reserve_wft_slot()
621            .times(1)
622            .returning(|| None);
623        failing_worker
624            .expect_namespace()
625            .return_const(namespace.clone());
626        failing_worker
627            .expect_task_queue()
628            .return_const(task_queue.clone());
629        failing_worker
630            .expect_deployment_options()
631            .return_const(WorkerDeploymentOptions {
632                version: temporalio_common::worker::WorkerDeploymentVersion {
633                    deployment_name: "test-deployment".to_string(),
634                    build_id: "build-fail".to_string(),
635                },
636                use_worker_versioning: true,
637                default_versioning_behavior: None,
638            });
639        failing_worker
640            .expect_worker_instance_key()
641            .return_const(failing_worker_id);
642        failing_worker
643            .expect_heartbeat_enabled()
644            .return_const(false);
645        failing_worker
646            .expect_worker_task_types()
647            .return_const(WorkerTaskTypes {
648                enable_workflows: true,
649                enable_local_activities: true,
650                enable_remote_activities: true,
651                enable_nexus: true,
652            });
653
654        // On a separate task queue
655        let succeeding_worker_id = Uuid::new_v4();
656        let mut succeeding_worker = MockClientWorker::new();
657        succeeding_worker.expect_try_reserve_wft_slot().times(0);
658        succeeding_worker
659            .expect_namespace()
660            .return_const(namespace.clone());
661        succeeding_worker
662            .expect_task_queue()
663            .return_const("other_task_queue".to_string());
664        succeeding_worker
665            .expect_deployment_options()
666            .return_const(None);
667        succeeding_worker
668            .expect_worker_instance_key()
669            .return_const(succeeding_worker_id);
670        succeeding_worker
671            .expect_heartbeat_enabled()
672            .return_const(false);
673        succeeding_worker
674            .expect_worker_task_types()
675            .return_const(WorkerTaskTypes {
676                enable_workflows: true,
677                enable_local_activities: true,
678                enable_remote_activities: true,
679                enable_nexus: true,
680            });
681
682        manager
683            .register(Arc::new(failing_worker), false)
684            .expect("failing worker registration succeeds");
685        manager
686            .register(Arc::new(succeeding_worker), false)
687            .expect("succeeding worker registration succeeds");
688
689        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
690        assert!(
691            reservation.is_none(),
692            "succeeding_worker should not be picked due to it being on a separate task queue"
693        );
694    }
695
696    #[test]
697    fn registry_keeps_one_provider_per_namespace() {
698        let manager = ClientWorkerSet::new();
699        let mut worker_keys = vec![];
700        let mut successful_registrations = 0;
701
702        for i in 0..10 {
703            let namespace = format!("myId{}", i % 3);
704            let mock_provider =
705                new_mock_provider(namespace, "bar_q".to_string(), false, false, false);
706            let worker_instance_key = mock_provider.worker_instance_key();
707
708            let result = manager.register_worker(Arc::new(mock_provider), false);
709            if let Err(err) = result {
710                // Should get error for overlapping worker task types
711                assert!(err.to_string().contains(
712                    "Registration of multiple workers with overlapping worker task types"
713                ));
714            } else {
715                successful_registrations += 1;
716                worker_keys.push(worker_instance_key);
717            }
718        }
719
720        assert_eq!(successful_registrations, 3);
721        assert_eq!(3, manager.num_providers());
722
723        let count = worker_keys.iter().fold(0, |count, key| {
724            manager.unregister_slot_provider(*key).unwrap();
725            manager.finalize_unregister(*key).unwrap();
726            // expect error since worker is already unregistered
727            let result = manager.unregister_slot_provider(*key);
728            assert!(result.is_err());
729            let result = manager.finalize_unregister(*key);
730            assert!(result.is_err());
731            count + 1
732        });
733        assert_eq!(3, count);
734        assert_eq!(0, manager.num_providers());
735    }
736
737    struct MockSharedNamespaceWorker {
738        namespace: String,
739        callbacks: Arc<RwLock<HashMap<Uuid, WorkerCallbacks>>>,
740    }
741
742    impl std::fmt::Debug for MockSharedNamespaceWorker {
743        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
744            f.debug_struct("MockSharedNamespaceWorker")
745                .field("namespace", &self.namespace)
746                .field("callbacks_count", &self.callbacks.read().len())
747                .finish()
748        }
749    }
750
751    impl MockSharedNamespaceWorker {
752        fn new(namespace: String) -> Self {
753            Self {
754                namespace,
755                callbacks: Arc::new(RwLock::new(HashMap::new())),
756            }
757        }
758    }
759
760    impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker {
761        fn namespace(&self) -> String {
762            self.namespace.clone()
763        }
764
765        fn register_callback(&self, worker_instance_key: Uuid, callbacks: WorkerCallbacks) {
766            self.callbacks
767                .write()
768                .insert(worker_instance_key, callbacks);
769        }
770
771        fn unregister_callback(
772            &self,
773            worker_instance_key: Uuid,
774        ) -> (Option<WorkerCallbacks>, bool) {
775            let mut callbacks = self.callbacks.write();
776            let callback = callbacks.remove(&worker_instance_key);
777            let is_empty = callbacks.is_empty();
778            (callback, is_empty)
779        }
780
781        fn num_workers(&self) -> usize {
782            self.callbacks.read().len()
783        }
784    }
785
786    fn new_mock_provider_with_heartbeat(
787        namespace: String,
788        task_queue: String,
789        heartbeat_enabled: bool,
790        build_id: Option<String>,
791    ) -> MockClientWorker {
792        let mut mock_provider = MockClientWorker::new();
793        mock_provider
794            .expect_try_reserve_wft_slot()
795            .returning(|| Some(new_mock_slot(false)));
796        mock_provider
797            .expect_namespace()
798            .return_const(namespace.clone());
799        mock_provider.expect_task_queue().return_const(task_queue);
800        mock_provider
801            .expect_heartbeat_enabled()
802            .return_const(heartbeat_enabled);
803        mock_provider
804            .expect_worker_instance_key()
805            .return_const(Uuid::new_v4());
806        let deployment_name = "test-deployment".to_string();
807        let build_id_for_closure = build_id.clone();
808        mock_provider
809            .expect_deployment_options()
810            .returning(move || {
811                build_id_for_closure
812                    .as_ref()
813                    .map(|build_id| WorkerDeploymentOptions {
814                        version: temporalio_common::worker::WorkerDeploymentVersion {
815                            deployment_name: deployment_name.clone(),
816                            build_id: build_id.clone(),
817                        },
818                        use_worker_versioning: true,
819                        default_versioning_behavior: None,
820                    })
821            });
822
823        if heartbeat_enabled {
824            mock_provider
825                .expect_heartbeat_callback()
826                .returning(|| Some(Arc::new(WorkerHeartbeat::default)));
827            mock_provider
828                .expect_cancel_activity_callback()
829                .returning(|| None);
830
831            let namespace_clone = namespace.clone();
832            mock_provider
833                .expect_new_shared_namespace_worker()
834                .returning(move || {
835                    Ok(Box::new(MockSharedNamespaceWorker::new(
836                        namespace_clone.clone(),
837                    )))
838                });
839        }
840
841        mock_provider
842            .expect_worker_task_types()
843            .return_const(WorkerTaskTypes {
844                enable_workflows: true,
845                enable_local_activities: true,
846                enable_remote_activities: true,
847                enable_nexus: true,
848            });
849
850        mock_provider
851    }
852
853    #[test]
854    fn duplicate_namespace_task_queue_registration_fails() {
855        let manager = ClientWorkerSet::new();
856
857        let worker1 = new_mock_provider_with_heartbeat(
858            "test_namespace".to_string(),
859            "test_queue".to_string(),
860            true,
861            None,
862        );
863
864        // Same namespace+task_queue but different worker instance
865        let worker2 = new_mock_provider_with_heartbeat(
866            "test_namespace".to_string(),
867            "test_queue".to_string(),
868            true,
869            None,
870        );
871
872        manager.register_worker(Arc::new(worker1), false).unwrap();
873
874        // second worker register should fail due to overlapping worker task types
875        let result = manager.register_worker(Arc::new(worker2), false);
876        assert!(result.is_err());
877        assert!(
878            result
879                .unwrap_err()
880                .to_string()
881                .contains("Registration of multiple workers with overlapping worker task types")
882        );
883
884        assert_eq!(1, manager.num_providers());
885        assert_eq!(manager.num_heartbeat_workers(), 1);
886
887        let impl_ref = manager.worker_manager.read();
888        assert_eq!(impl_ref.shared_worker.len(), 1);
889        assert!(impl_ref.shared_worker.contains_key("test_namespace"));
890    }
891
892    #[test]
893    fn duplicate_namespace_with_different_build_ids_succeeds() {
894        let manager = ClientWorkerSet::new();
895        let namespace = "test_namespace".to_string();
896        let task_queue = "test_queue".to_string();
897
898        let worker1 =
899            new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
900        let worker1_instance_key = worker1.worker_instance_key();
901        let worker2 = new_mock_provider_with_heartbeat(
902            namespace.clone(),
903            task_queue.clone(),
904            false,
905            Some("build-1".to_string()),
906        );
907        let worker2_instance_key = worker2.worker_instance_key();
908        let worker3 =
909            new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
910        let worker4 = new_mock_provider_with_heartbeat(
911            namespace.clone(),
912            task_queue.clone(),
913            false,
914            Some("build-1".to_string()),
915        );
916
917        manager.register_worker(Arc::new(worker1), false).unwrap();
918
919        manager
920            .register_worker(Arc::new(worker2), false)
921            .expect("worker with new build ID should register");
922        assert_eq!(2, manager.num_providers());
923
924        assert!(
925            manager
926                .register_worker(Arc::new(worker3), false)
927                .unwrap_err()
928                .to_string()
929                .contains("Registration of multiple workers with overlapping worker task types")
930        );
931
932        assert!(
933            manager
934                .register_worker(Arc::new(worker4), false)
935                .unwrap_err()
936                .to_string()
937                .contains("Registration of multiple workers with overlapping worker task types")
938        );
939        assert_eq!(2, manager.num_providers());
940
941        {
942            let impl_ref = manager.worker_manager.read();
943            let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
944            let providers = impl_ref
945                .slot_providers
946                .get(&slot_key)
947                .expect("slot providers should exist for namespace/task queue");
948            assert_eq!(2, providers.len());
949
950            assert_eq!(providers[0].worker_id, worker1_instance_key);
951            assert_eq!(providers[0].build_id, None);
952            assert_eq!(providers[1].worker_id, worker2_instance_key);
953            assert_eq!(providers[1].build_id, Some("build-1".to_string()));
954        }
955
956        manager
957            .unregister_slot_provider(worker2_instance_key)
958            .unwrap();
959        manager.finalize_unregister(worker2_instance_key).unwrap();
960
961        {
962            let impl_ref = manager.worker_manager.read();
963            let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
964            let providers = impl_ref
965                .slot_providers
966                .get(&slot_key)
967                .expect("slot providers should exist for namespace/task queue");
968
969            assert_eq!(1, providers.len());
970            assert_eq!(providers[0].worker_id, worker1_instance_key);
971            assert_eq!(providers[0].build_id, None);
972        }
973    }
974
975    #[test]
976    fn multiple_workers_same_namespace_share_heartbeat_manager() {
977        let manager = ClientWorkerSet::new();
978
979        let worker1 = new_mock_provider_with_heartbeat(
980            "shared_namespace".to_string(),
981            "queue1".to_string(),
982            true,
983            None,
984        );
985
986        // Same namespace but different task queue
987        let worker2 = new_mock_provider_with_heartbeat(
988            "shared_namespace".to_string(),
989            "queue2".to_string(),
990            true,
991            None,
992        );
993
994        manager.register_worker(Arc::new(worker1), false).unwrap();
995        manager.register_worker(Arc::new(worker2), false).unwrap();
996
997        assert_eq!(2, manager.num_providers());
998        assert_eq!(manager.num_heartbeat_workers(), 2);
999
1000        let impl_ref = manager.worker_manager.read();
1001        assert_eq!(impl_ref.shared_worker.len(), 1);
1002        assert!(impl_ref.shared_worker.contains_key("shared_namespace"));
1003
1004        let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap();
1005        assert_eq!(shared_worker.namespace(), "shared_namespace");
1006    }
1007
1008    #[test]
1009    fn different_namespaces_get_separate_heartbeat_managers() {
1010        let manager = ClientWorkerSet::new();
1011        let worker1 = new_mock_provider_with_heartbeat(
1012            "namespace1".to_string(),
1013            "queue1".to_string(),
1014            true,
1015            None,
1016        );
1017        let worker2 = new_mock_provider_with_heartbeat(
1018            "namespace2".to_string(),
1019            "queue1".to_string(),
1020            true,
1021            None,
1022        );
1023
1024        manager.register_worker(Arc::new(worker1), false).unwrap();
1025        manager.register_worker(Arc::new(worker2), false).unwrap();
1026
1027        assert_eq!(2, manager.num_providers());
1028        assert_eq!(manager.num_heartbeat_workers(), 2);
1029
1030        let impl_ref = manager.worker_manager.read();
1031        assert_eq!(impl_ref.num_heartbeat_workers(), 2);
1032        assert!(impl_ref.shared_worker.contains_key("namespace1"));
1033        assert!(impl_ref.shared_worker.contains_key("namespace2"));
1034    }
1035
1036    #[test]
1037    fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() {
1038        let manager = ClientWorkerSet::new();
1039
1040        // Create two workers with same namespace but different task queues
1041        let worker1 = new_mock_provider_with_heartbeat(
1042            "test_namespace".to_string(),
1043            "queue1".to_string(),
1044            true,
1045            None,
1046        );
1047        let worker2 = new_mock_provider_with_heartbeat(
1048            "test_namespace".to_string(),
1049            "queue2".to_string(),
1050            true,
1051            None,
1052        );
1053        let worker_instance_key1 = worker1.worker_instance_key();
1054        let worker_instance_key2 = worker2.worker_instance_key();
1055
1056        assert_ne!(worker_instance_key1, worker_instance_key2);
1057
1058        manager.register_worker(Arc::new(worker1), false).unwrap();
1059        manager.register_worker(Arc::new(worker2), false).unwrap();
1060
1061        // Verify initial state: 2 slot providers, 2 heartbeat workers, 1 shared worker
1062        assert_eq!(2, manager.num_providers());
1063        assert_eq!(manager.num_heartbeat_workers(), 2);
1064
1065        let impl_ref = manager.worker_manager.read();
1066        assert_eq!(impl_ref.shared_worker.len(), 1);
1067        assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1068        assert_eq!(
1069            impl_ref
1070                .shared_worker
1071                .get("test_namespace")
1072                .unwrap()
1073                .num_workers(),
1074            2
1075        );
1076        drop(impl_ref);
1077
1078        // Unregister first worker
1079        manager
1080            .unregister_slot_provider(worker_instance_key1)
1081            .unwrap();
1082        manager.finalize_unregister(worker_instance_key1).unwrap();
1083
1084        // After unregistering first worker: 1 slot provider, 1 heartbeat worker, shared worker still exists
1085        assert_eq!(1, manager.num_providers());
1086        assert_eq!(manager.num_heartbeat_workers(), 1);
1087
1088        let impl_ref = manager.worker_manager.read();
1089        assert_eq!(impl_ref.num_heartbeat_workers(), 1); // SharedNamespaceWorker still exists
1090        assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1091        assert_eq!(
1092            impl_ref
1093                .shared_worker
1094                .get("test_namespace")
1095                .unwrap()
1096                .num_workers(),
1097            1
1098        );
1099        drop(impl_ref);
1100
1101        // Unregister second worker
1102        manager
1103            .unregister_slot_provider(worker_instance_key2)
1104            .unwrap();
1105        manager.finalize_unregister(worker_instance_key2).unwrap();
1106
1107        // After unregistering last worker: 0 slot providers, 0 heartbeat workers, shared worker is removed
1108        assert_eq!(0, manager.num_providers());
1109        assert_eq!(manager.num_heartbeat_workers(), 0);
1110
1111        let impl_ref = manager.worker_manager.read();
1112        assert_eq!(impl_ref.shared_worker.len(), 0); // SharedNamespaceWorker is cleaned up
1113        assert!(!impl_ref.shared_worker.contains_key("test_namespace"));
1114    }
1115
1116    #[test]
1117    fn workflow_and_activity_only_workers_coexist() {
1118        let manager = ClientWorkerSet::new();
1119        let namespace = "test_namespace".to_string();
1120        let task_queue = "test_queue".to_string();
1121
1122        let mut workflow_nexus_worker = MockClientWorker::new();
1123        workflow_nexus_worker
1124            .expect_namespace()
1125            .return_const(namespace.clone());
1126        workflow_nexus_worker
1127            .expect_task_queue()
1128            .return_const(task_queue.clone());
1129        workflow_nexus_worker
1130            .expect_deployment_options()
1131            .return_const(None);
1132        workflow_nexus_worker
1133            .expect_worker_instance_key()
1134            .return_const(Uuid::new_v4());
1135        workflow_nexus_worker
1136            .expect_heartbeat_enabled()
1137            .return_const(false);
1138        workflow_nexus_worker
1139            .expect_worker_task_types()
1140            .return_const(WorkerTaskTypes {
1141                enable_workflows: true,
1142                enable_local_activities: false,
1143                enable_remote_activities: false,
1144                enable_nexus: true,
1145            });
1146
1147        let mut activity_worker = MockClientWorker::new();
1148        activity_worker
1149            .expect_namespace()
1150            .return_const(namespace.clone());
1151        activity_worker
1152            .expect_task_queue()
1153            .return_const(task_queue.clone());
1154        activity_worker
1155            .expect_deployment_options()
1156            .return_const(None);
1157        activity_worker
1158            .expect_worker_instance_key()
1159            .return_const(Uuid::new_v4());
1160        activity_worker
1161            .expect_heartbeat_enabled()
1162            .return_const(false);
1163        activity_worker
1164            .expect_worker_task_types()
1165            .return_const(WorkerTaskTypes {
1166                enable_workflows: false,
1167                enable_local_activities: false,
1168                enable_remote_activities: true,
1169                enable_nexus: false,
1170            });
1171        activity_worker.expect_try_reserve_wft_slot().times(0); // Should not be called for activity-only worker
1172
1173        manager
1174            .register_worker(Arc::new(workflow_nexus_worker), false)
1175            .expect("workflow-nexus worker should register");
1176        manager
1177            .register_worker(Arc::new(activity_worker), false)
1178            .expect("activity-only worker should register");
1179
1180        assert_eq!(2, manager.num_providers());
1181    }
1182
1183    #[test]
1184    fn overlapping_capabilities_rejected() {
1185        let manager = ClientWorkerSet::new();
1186        let namespace = "test_namespace".to_string();
1187        let task_queue = "test_queue".to_string();
1188
1189        // workflow+activity worker
1190        let mut worker1 = MockClientWorker::new();
1191        worker1.expect_namespace().return_const(namespace.clone());
1192        worker1.expect_task_queue().return_const(task_queue.clone());
1193        worker1.expect_deployment_options().return_const(None);
1194        worker1
1195            .expect_worker_instance_key()
1196            .return_const(Uuid::new_v4());
1197        worker1.expect_heartbeat_enabled().return_const(false);
1198        worker1
1199            .expect_worker_task_types()
1200            .return_const(WorkerTaskTypes {
1201                enable_workflows: true,
1202                enable_local_activities: true,
1203                enable_remote_activities: true,
1204                enable_nexus: false,
1205            });
1206
1207        // workflow+activity worker
1208        let mut worker2 = MockClientWorker::new();
1209        worker2.expect_namespace().return_const(namespace.clone());
1210        worker2.expect_task_queue().return_const(task_queue.clone());
1211        worker2.expect_deployment_options().return_const(None);
1212        worker2
1213            .expect_worker_instance_key()
1214            .return_const(Uuid::new_v4());
1215        worker2.expect_heartbeat_enabled().return_const(false);
1216        worker2
1217            .expect_worker_task_types()
1218            .return_const(WorkerTaskTypes {
1219                enable_workflows: true,
1220                enable_local_activities: true,
1221                enable_remote_activities: true,
1222                enable_nexus: false,
1223            });
1224
1225        manager
1226            .register_worker(Arc::new(worker1), false)
1227            .expect("first worker should register");
1228
1229        let result = manager.register_worker(Arc::new(worker2), false);
1230        assert!(result.is_err());
1231        assert!(
1232            result
1233                .unwrap_err()
1234                .to_string()
1235                .contains("overlapping worker task types")
1236        );
1237
1238        // activity-only worker
1239        let mut worker3 = MockClientWorker::new();
1240        worker3.expect_namespace().return_const(namespace.clone());
1241        worker3.expect_task_queue().return_const(task_queue.clone());
1242        worker3.expect_deployment_options().return_const(None);
1243        worker3
1244            .expect_worker_instance_key()
1245            .return_const(Uuid::new_v4());
1246        worker3.expect_heartbeat_enabled().return_const(false);
1247        worker3
1248            .expect_worker_task_types()
1249            .return_const(WorkerTaskTypes {
1250                enable_workflows: false,
1251                enable_local_activities: false,
1252                enable_remote_activities: true,
1253                enable_nexus: false,
1254            });
1255
1256        let result = manager.register_worker(Arc::new(worker3), false);
1257        assert!(result.is_err());
1258        assert!(
1259            result
1260                .unwrap_err()
1261                .to_string()
1262                .contains("overlapping worker task types")
1263        );
1264    }
1265
1266    #[test]
1267    fn wft_slot_reservation_ignores_non_workflow_workers() {
1268        let mut manager_impl = ClientWorkerSetImpl::new();
1269        let namespace = "test_namespace".to_string();
1270        let task_queue = "test_queue".to_string();
1271
1272        let mut activity_worker = MockClientWorker::new();
1273        activity_worker
1274            .expect_namespace()
1275            .return_const(namespace.clone());
1276        activity_worker
1277            .expect_task_queue()
1278            .return_const(task_queue.clone());
1279        activity_worker
1280            .expect_deployment_options()
1281            .return_const(None);
1282        activity_worker
1283            .expect_worker_instance_key()
1284            .return_const(Uuid::new_v4());
1285        activity_worker
1286            .expect_heartbeat_enabled()
1287            .return_const(false);
1288        activity_worker
1289            .expect_worker_task_types()
1290            .return_const(WorkerTaskTypes {
1291                enable_workflows: false,
1292                enable_local_activities: false,
1293                enable_remote_activities: true,
1294                enable_nexus: false,
1295            });
1296
1297        let mut nexus_worker = MockClientWorker::new();
1298        nexus_worker
1299            .expect_namespace()
1300            .return_const(namespace.clone());
1301        nexus_worker
1302            .expect_task_queue()
1303            .return_const(task_queue.clone());
1304        nexus_worker.expect_deployment_options().return_const(None);
1305        nexus_worker
1306            .expect_worker_instance_key()
1307            .return_const(Uuid::new_v4());
1308        nexus_worker.expect_heartbeat_enabled().return_const(false);
1309        nexus_worker
1310            .expect_worker_task_types()
1311            .return_const(WorkerTaskTypes {
1312                enable_workflows: false,
1313                enable_local_activities: false,
1314                enable_remote_activities: false,
1315                enable_nexus: true,
1316            });
1317
1318        manager_impl
1319            .register(Arc::new(activity_worker), false)
1320            .expect("activity worker should register");
1321        manager_impl
1322            .register(Arc::new(nexus_worker), false)
1323            .expect("nexus worker should register");
1324
1325        let reservation = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1326        assert!(
1327            reservation.is_none(),
1328            "should not find workflow workers when only activity/nexus workers registered"
1329        );
1330
1331        // Now register a workflow worker
1332        let mut workflow_worker = MockClientWorker::new();
1333        workflow_worker
1334            .expect_namespace()
1335            .return_const(namespace.clone());
1336        workflow_worker
1337            .expect_task_queue()
1338            .return_const(task_queue.clone());
1339        workflow_worker
1340            .expect_deployment_options()
1341            .return_const(None);
1342        workflow_worker
1343            .expect_worker_instance_key()
1344            .return_const(Uuid::new_v4());
1345        workflow_worker
1346            .expect_heartbeat_enabled()
1347            .return_const(false);
1348        workflow_worker
1349            .expect_worker_task_types()
1350            .return_const(WorkerTaskTypes {
1351                enable_workflows: true,
1352                enable_local_activities: true,
1353                enable_remote_activities: false,
1354                enable_nexus: false,
1355            });
1356        workflow_worker
1357            .expect_try_reserve_wft_slot()
1358            .times(1)
1359            .returning(|| Some(new_mock_slot(false)));
1360
1361        manager_impl
1362            .register(Arc::new(workflow_worker), false)
1363            .expect("workflow worker should register");
1364
1365        let reservation = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1366        assert!(
1367            reservation.is_some(),
1368            "should find workflow worker after it's registered"
1369        );
1370    }
1371
1372    #[test]
1373    fn worker_invalid_type_config_rejected() {
1374        let manager = ClientWorkerSet::new();
1375
1376        // no types enabled
1377        let mut worker = MockClientWorker::new();
1378        worker
1379            .expect_namespace()
1380            .return_const("test_namespace".to_string());
1381        worker
1382            .expect_task_queue()
1383            .return_const("test_queue".to_string());
1384        worker.expect_deployment_options().return_const(None);
1385        worker
1386            .expect_worker_instance_key()
1387            .return_const(Uuid::new_v4());
1388        worker.expect_heartbeat_enabled().return_const(false);
1389        worker
1390            .expect_worker_task_types()
1391            .return_const(WorkerTaskTypes {
1392                enable_workflows: false,
1393                enable_local_activities: false,
1394                enable_remote_activities: false,
1395                enable_nexus: false,
1396            });
1397
1398        let result = manager.register_worker(Arc::new(worker), false);
1399        assert!(result.is_err());
1400        assert!(
1401            result
1402                .unwrap_err()
1403                .to_string()
1404                .contains("must have at least one capability enabled")
1405        );
1406
1407        // local activities enabled without workflows
1408        let mut worker = MockClientWorker::new();
1409        worker
1410            .expect_namespace()
1411            .return_const("test_namespace".to_string());
1412        worker
1413            .expect_task_queue()
1414            .return_const("test_queue".to_string());
1415        worker.expect_deployment_options().return_const(None);
1416        worker
1417            .expect_worker_instance_key()
1418            .return_const(Uuid::new_v4());
1419        worker.expect_heartbeat_enabled().return_const(false);
1420        worker
1421            .expect_worker_task_types()
1422            .return_const(WorkerTaskTypes {
1423                enable_workflows: false,
1424                enable_local_activities: true,
1425                enable_remote_activities: true,
1426                enable_nexus: false,
1427            });
1428
1429        let result = manager.register_worker(Arc::new(worker), false);
1430        assert!(result.is_err());
1431        assert_eq!(
1432            result.unwrap_err().to_string(),
1433            "Local activities cannot be enabled without workflows".to_string()
1434        );
1435    }
1436
1437    #[test]
1438    fn unregister_with_multiple_workers() {
1439        let manager = ClientWorkerSet::new();
1440        let namespace = "test_namespace".to_string();
1441        let task_queue = "test_queue".to_string();
1442
1443        // workflow-only worker
1444        let mut workflow_worker = MockClientWorker::new();
1445        workflow_worker
1446            .expect_namespace()
1447            .return_const(namespace.clone());
1448        workflow_worker
1449            .expect_task_queue()
1450            .return_const(task_queue.clone());
1451        workflow_worker
1452            .expect_deployment_options()
1453            .return_const(None);
1454        let wf_worker_key = Uuid::new_v4();
1455        workflow_worker
1456            .expect_worker_instance_key()
1457            .return_const(wf_worker_key);
1458        workflow_worker
1459            .expect_heartbeat_enabled()
1460            .return_const(false);
1461        workflow_worker
1462            .expect_worker_task_types()
1463            .return_const(WorkerTaskTypes {
1464                enable_workflows: true,
1465                enable_local_activities: true,
1466                enable_remote_activities: false,
1467                enable_nexus: false,
1468            });
1469        workflow_worker
1470            .expect_try_reserve_wft_slot()
1471            .returning(|| Some(new_mock_slot(false)));
1472
1473        // activity-only worker
1474        let mut activity_worker = MockClientWorker::new();
1475        activity_worker
1476            .expect_namespace()
1477            .return_const(namespace.clone());
1478        activity_worker
1479            .expect_task_queue()
1480            .return_const(task_queue.clone());
1481        activity_worker
1482            .expect_deployment_options()
1483            .return_const(None);
1484        let act_worker_key = Uuid::new_v4();
1485        activity_worker
1486            .expect_worker_instance_key()
1487            .return_const(act_worker_key);
1488        activity_worker
1489            .expect_heartbeat_enabled()
1490            .return_const(false);
1491        activity_worker
1492            .expect_worker_task_types()
1493            .return_const(WorkerTaskTypes {
1494                enable_workflows: false,
1495                enable_local_activities: false,
1496                enable_remote_activities: true,
1497                enable_nexus: false,
1498            });
1499
1500        manager
1501            .register_worker(Arc::new(workflow_worker), false)
1502            .expect("workflow worker should register");
1503        manager
1504            .register_worker(Arc::new(activity_worker), false)
1505            .expect("activity worker should register");
1506
1507        assert_eq!(2, manager.num_providers());
1508
1509        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1510        assert!(
1511            reservation.is_some(),
1512            "should be able to reserve slot from workflow worker"
1513        );
1514
1515        manager
1516            .unregister_slot_provider(wf_worker_key)
1517            .expect("should unregister slot provider for workflow worker");
1518        manager
1519            .finalize_unregister(wf_worker_key)
1520            .expect("should finalize unregister for workflow worker");
1521
1522        // Activity worker should still be registered
1523        assert_eq!(1, manager.num_providers());
1524
1525        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1526        assert!(
1527            reservation.is_none(),
1528            "should not find workflow worker after unregistration"
1529        );
1530
1531        manager
1532            .unregister_slot_provider(act_worker_key)
1533            .expect("should unregister slot provider for activity worker");
1534        manager
1535            .finalize_unregister(act_worker_key)
1536            .expect("should finalize unregister for activity worker");
1537
1538        assert_eq!(0, manager.num_providers());
1539    }
1540
1541    #[test]
1542    fn worker_unregister_order() {
1543        let manager = ClientWorkerSet::new();
1544        let worker = new_mock_provider_with_heartbeat(
1545            "namespace1".to_string(),
1546            "queue1".to_string(),
1547            true,
1548            None,
1549        );
1550        let worker_instance_key = worker.worker_instance_key();
1551        manager.register_worker(Arc::new(worker), false).unwrap();
1552
1553        let res = manager.finalize_unregister(worker_instance_key);
1554        assert!(res.is_err());
1555        let err_string = res.err().map(|e| e.to_string()).unwrap();
1556        assert!(err_string.contains("Worker still in slot_providers during finalize"));
1557
1558        // previous incorrect call to finalize_unregister should not cause any state leaks when
1559        // properly removed later
1560        manager
1561            .unregister_slot_provider(worker_instance_key)
1562            .unwrap();
1563        manager.finalize_unregister(worker_instance_key).unwrap();
1564    }
1565}