Skip to main content

rust_pipe/worker/
mod.rs

1/// Worker pool management and least-loaded selection.
2use chrono::{DateTime, Utc};
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6
7use crate::transport::WorkerLanguage;
8
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Information about a connected worker.
///
/// A value of this type is handed to the pool at registration and is then
/// stored and mutated in place by `WorkerPool` (heartbeats, reservations,
/// drain, dead-worker detection).
pub struct WorkerInfo {
    /// Unique worker identifier; used as the key under which the pool stores this record.
    pub id: String,
    /// Language of the worker (see `WorkerLanguage` in the transport module).
    pub language: WorkerLanguage,
    /// Task type names this worker accepts. Selection matches by exact string equality.
    pub supported_tasks: Vec<String>,
    /// Maximum number of tasks the worker may run at once; a worker with
    /// `active_tasks >= max_concurrency` is never selected.
    pub max_concurrency: u32,
    /// Current lifecycle status of the worker.
    pub status: WorkerStatus,
    /// Number of tasks currently counted against this worker. Updated by
    /// reservations/completions and overwritten by each heartbeat.
    pub active_tasks: u32,
    /// When the worker registered. Supplied by the caller that builds this value.
    pub registered_at: DateTime<Utc>,
    /// Time of the most recent heartbeat; compared against the pool's heartbeat
    /// timeout to detect dead workers.
    pub last_heartbeat: DateTime<Utc>,
    /// Tags/metadata for targeted dispatch routing.
    /// Defaults to an empty list when absent from deserialized input.
    #[serde(default)]
    pub tags: Vec<String>,
}
24
25#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
26/// Current status of a worker in the pool.
27pub enum WorkerStatus {
28    Active,
29    Busy,
30    Draining,
31    Dead,
32}
33
34/// Errors related to pool management operations.
35#[derive(Debug, thiserror::Error, PartialEq)]
36pub enum PoolError {
37    #[error("Pool is at maximum capacity ({max}), cannot register worker")]
38    PoolFull { max: u32 },
39
40    #[error("Worker not found: {worker_id}")]
41    WorkerNotFound { worker_id: String },
42
43    #[error("Worker '{worker_id}' is at capacity")]
44    WorkerAtCapacity { worker_id: String },
45
46    #[error("Worker '{worker_id}' is draining or dead")]
47    WorkerUnavailable { worker_id: String },
48}
49
/// Pool of connected workers with capacity-aware task routing.
pub struct WorkerPool {
    /// Registered workers, keyed by worker ID.
    workers: DashMap<String, WorkerInfo>,
    /// Heartbeat age, in milliseconds, beyond which a worker is marked dead.
    heartbeat_timeout_ms: u64,
    /// Maximum number of registered workers; `None` means no limit.
    max_pool_size: Option<u32>,
    /// Minimum desired number of registered workers; `None` disables the check.
    min_pool_size: Option<u32>,
    /// Callback invoked with the current worker count when, after a
    /// deregister/remove call, the pool holds fewer than `min_pool_size` workers.
    on_pool_below_min: Option<Arc<dyn Fn(u32) + Send + Sync>>,
}
58
59impl WorkerPool {
60    pub fn new(heartbeat_timeout_ms: u64) -> Self {
61        Self {
62            workers: DashMap::new(),
63            heartbeat_timeout_ms,
64            max_pool_size: None,
65            min_pool_size: None,
66            on_pool_below_min: None,
67        }
68    }
69
70    pub fn with_limits(
71        heartbeat_timeout_ms: u64,
72        max_pool_size: Option<u32>,
73        min_pool_size: Option<u32>,
74        on_pool_below_min: Option<Arc<dyn Fn(u32) + Send + Sync>>,
75    ) -> Self {
76        Self {
77            workers: DashMap::new(),
78            heartbeat_timeout_ms,
79            max_pool_size,
80            min_pool_size,
81            on_pool_below_min,
82        }
83    }
84
85    /// Register a worker. If the pool is at max capacity (and this is not a
86    /// re-registration of an existing worker), the registration is silently rejected.
87    /// Use [`try_register`](Self::try_register) for explicit error handling.
88    pub fn register(&self, info: WorkerInfo) {
89        if let Err(e) = self.try_register(info) {
90            tracing::warn!(error = %e, "Worker registration rejected");
91        }
92    }
93
94    /// Register a worker, returning an error if the pool is at max capacity.
95    /// Re-registration of an existing worker (same ID) always succeeds (updates in place).
96    pub fn try_register(&self, info: WorkerInfo) -> Result<(), PoolError> {
97        if let Some(max) = self.max_pool_size {
98            let is_reregistration = self.workers.contains_key(&info.id);
99            if !is_reregistration && self.workers.len() as u32 >= max {
100                return Err(PoolError::PoolFull { max });
101            }
102        }
103        tracing::info!(
104            worker_id = %info.id,
105            language = ?info.language,
106            tasks = ?info.supported_tasks,
107            "Registering worker"
108        );
109        self.workers.insert(info.id.clone(), info);
110        Ok(())
111    }
112
113    pub fn deregister(&self, worker_id: &str) {
114        self.workers.remove(worker_id);
115        self.check_below_min();
116    }
117
118    pub fn heartbeat(&self, worker_id: &str, active_tasks: u32) {
119        if let Some(mut worker) = self.workers.get_mut(worker_id) {
120            worker.last_heartbeat = Utc::now();
121            worker.active_tasks = active_tasks;
122            // Don't overwrite Draining status — drain must be explicit
123            if worker.status != WorkerStatus::Draining {
124                worker.status = if active_tasks >= worker.max_concurrency {
125                    WorkerStatus::Busy
126                } else {
127                    WorkerStatus::Active
128                };
129            }
130        }
131    }
132
133    /// Atomically selects a worker and reserves capacity.
134    /// Returns the worker ID if one is available, or None.
135    /// This avoids the TOCTOU race between select and dispatch.
136    pub fn select_and_reserve(&self, task_type: &str) -> Option<String> {
137        let mut best_id: Option<String> = None;
138        let mut best_capacity: u32 = 0;
139
140        for entry in self.workers.iter() {
141            let worker = entry.value();
142            if worker.status == WorkerStatus::Dead || worker.status == WorkerStatus::Draining {
143                continue;
144            }
145            if !worker.supported_tasks.iter().any(|t| t == task_type) {
146                continue;
147            }
148            if worker.active_tasks >= worker.max_concurrency {
149                continue;
150            }
151
152            let available = worker.max_concurrency - worker.active_tasks;
153            if best_id.is_none() || available > best_capacity {
154                best_id = Some(worker.id.clone());
155                best_capacity = available;
156            }
157        }
158
159        let worker_id = best_id?;
160
161        // Atomically increment under the entry lock
162        if let Some(mut worker) = self.workers.get_mut(&worker_id) {
163            if worker.active_tasks >= worker.max_concurrency {
164                return None;
165            }
166            worker.active_tasks += 1;
167            if worker.active_tasks >= worker.max_concurrency {
168                worker.status = WorkerStatus::Busy;
169            }
170            Some(worker_id)
171        } else {
172            None
173        }
174    }
175
176    /// Selects the least-loaded worker for a task type WITHOUT modifying state.
177    /// Use `select_and_reserve` for dispatch to avoid TOCTOU races.
178    pub fn select_worker(&self, task_type: &str) -> Option<String> {
179        let mut best_id: Option<String> = None;
180        let mut best_capacity: u32 = 0;
181
182        for entry in self.workers.iter() {
183            let worker = entry.value();
184            if worker.status == WorkerStatus::Dead || worker.status == WorkerStatus::Draining {
185                continue;
186            }
187            if !worker.supported_tasks.iter().any(|t| t == task_type) {
188                continue;
189            }
190            if worker.active_tasks >= worker.max_concurrency {
191                continue;
192            }
193
194            let available = worker.max_concurrency - worker.active_tasks;
195            if best_id.is_none() || available > best_capacity {
196                best_id = Some(worker.id.clone());
197                best_capacity = available;
198            }
199        }
200
201        best_id
202    }
203
204    pub fn mark_task_dispatched(&self, worker_id: &str) {
205        if let Some(mut worker) = self.workers.get_mut(worker_id) {
206            worker.active_tasks += 1;
207            if worker.active_tasks >= worker.max_concurrency {
208                worker.status = WorkerStatus::Busy;
209            }
210        }
211    }
212
213    pub fn mark_task_completed(&self, worker_id: &str) {
214        if let Some(mut worker) = self.workers.get_mut(worker_id) {
215            worker.active_tasks = worker.active_tasks.saturating_sub(1);
216            if worker.active_tasks < worker.max_concurrency && worker.status == WorkerStatus::Busy {
217                worker.status = WorkerStatus::Active;
218            }
219        }
220    }
221
222    pub fn detect_dead_workers(&self) -> Vec<String> {
223        let now = Utc::now();
224        let mut dead = Vec::new();
225
226        for mut entry in self.workers.iter_mut() {
227            let elapsed_ms = (now - entry.last_heartbeat).num_milliseconds().max(0) as u64;
228            if elapsed_ms > self.heartbeat_timeout_ms {
229                entry.status = WorkerStatus::Dead;
230                dead.push(entry.id.clone());
231            }
232        }
233
234        dead
235    }
236
237    pub fn active_workers(&self) -> Vec<WorkerInfo> {
238        self.workers
239            .iter()
240            .filter(|w| w.status == WorkerStatus::Active || w.status == WorkerStatus::Busy)
241            .map(|w| w.value().clone())
242            .collect()
243    }
244
245    pub fn count(&self) -> usize {
246        self.workers.len()
247    }
248
249    pub fn stats(&self) -> PoolStats {
250        let mut stats = PoolStats::default();
251        for entry in self.workers.iter() {
252            stats.total += 1;
253            match entry.status {
254                WorkerStatus::Active => stats.active += 1,
255                WorkerStatus::Busy => stats.busy += 1,
256                WorkerStatus::Draining => stats.draining += 1,
257                WorkerStatus::Dead => stats.dead += 1,
258            }
259            stats.total_capacity += entry.max_concurrency;
260            stats.used_capacity += entry.active_tasks;
261        }
262        stats
263    }
264
265    /// List all connected workers with their full info.
266    pub fn workers(&self) -> Vec<WorkerInfo> {
267        self.workers.iter().map(|w| w.value().clone()).collect()
268    }
269
270    /// Set a worker's status to Draining so no new tasks are routed to it.
271    /// Existing tasks will finish normally.
272    pub fn drain_worker(&self, worker_id: &str) -> Result<(), PoolError> {
273        if let Some(mut worker) = self.workers.get_mut(worker_id) {
274            worker.status = WorkerStatus::Draining;
275            tracing::info!(worker_id = %worker_id, "Worker set to draining");
276            Ok(())
277        } else {
278            Err(PoolError::WorkerNotFound {
279                worker_id: worker_id.to_string(),
280            })
281        }
282    }
283
284    /// Force-remove a worker from the pool. Returns the list of pending task IDs
285    /// that were assigned to this worker (caller is responsible for failing them).
286    pub fn remove_worker(&self, worker_id: &str) -> Result<(), PoolError> {
287        if self.workers.remove(worker_id).is_some() {
288            tracing::info!(worker_id = %worker_id, "Worker force-removed from pool");
289            self.check_below_min();
290            Ok(())
291        } else {
292            Err(PoolError::WorkerNotFound {
293                worker_id: worker_id.to_string(),
294            })
295        }
296    }
297
298    /// Select a worker that has a matching tag and reserve capacity on it.
299    pub fn select_and_reserve_with_tag(&self, tag: &str, task_type: &str) -> Option<String> {
300        let mut best_id: Option<String> = None;
301        let mut best_capacity: u32 = 0;
302
303        for entry in self.workers.iter() {
304            let worker = entry.value();
305            if worker.status == WorkerStatus::Dead || worker.status == WorkerStatus::Draining {
306                continue;
307            }
308            if !worker.tags.iter().any(|t| t == tag) {
309                continue;
310            }
311            if !worker.supported_tasks.iter().any(|t| t == task_type) {
312                continue;
313            }
314            if worker.active_tasks >= worker.max_concurrency {
315                continue;
316            }
317
318            let available = worker.max_concurrency - worker.active_tasks;
319            if best_id.is_none() || available > best_capacity {
320                best_id = Some(worker.id.clone());
321                best_capacity = available;
322            }
323        }
324
325        let worker_id = best_id?;
326
327        // Atomically increment under the entry lock
328        if let Some(mut worker) = self.workers.get_mut(&worker_id) {
329            if worker.active_tasks >= worker.max_concurrency {
330                return None;
331            }
332            worker.active_tasks += 1;
333            if worker.active_tasks >= worker.max_concurrency {
334                worker.status = WorkerStatus::Busy;
335            }
336            Some(worker_id)
337        } else {
338            None
339        }
340    }
341
342    /// Reserve capacity on a specific worker by ID.
343    /// Returns Ok(()) if capacity was reserved, or a specific error explaining why not.
344    pub fn reserve_specific_worker(&self, worker_id: &str) -> Result<(), PoolError> {
345        if let Some(mut worker) = self.workers.get_mut(worker_id) {
346            if worker.status == WorkerStatus::Dead || worker.status == WorkerStatus::Draining {
347                return Err(PoolError::WorkerUnavailable {
348                    worker_id: worker_id.to_string(),
349                });
350            }
351            if worker.active_tasks >= worker.max_concurrency {
352                return Err(PoolError::WorkerAtCapacity {
353                    worker_id: worker_id.to_string(),
354                });
355            }
356            worker.active_tasks += 1;
357            if worker.active_tasks >= worker.max_concurrency {
358                worker.status = WorkerStatus::Busy;
359            }
360            Ok(())
361        } else {
362            Err(PoolError::WorkerNotFound {
363                worker_id: worker_id.to_string(),
364            })
365        }
366    }
367
368    /// Check if the pool is below the minimum size and fire the callback if so.
369    fn check_below_min(&self) {
370        if let Some(min) = self.min_pool_size {
371            let current = self.workers.len() as u32;
372            if current < min {
373                if let Some(ref cb) = self.on_pool_below_min {
374                    cb(current);
375                }
376            }
377        }
378    }
379}
380
#[derive(Debug, Default, Serialize)]
/// Aggregate statistics about the worker pool.
pub struct PoolStats {
    /// Number of registered workers, regardless of status.
    pub total: u32,
    /// Number of workers with status `Active`.
    pub active: u32,
    /// Number of workers with status `Busy`.
    pub busy: u32,
    /// Number of workers with status `Draining`.
    pub draining: u32,
    /// Number of workers with status `Dead`.
    pub dead: u32,
    /// Sum of `max_concurrency` over all registered workers (every status included).
    pub total_capacity: u32,
    /// Sum of `active_tasks` over all registered workers (every status included).
    pub used_capacity: u32,
}
392
393#[cfg(test)]
394mod tests {
395    use super::*;
396    use crate::transport::WorkerLanguage;
397
398    fn make_worker(id: &str, tasks: Vec<&str>, max_concurrency: u32) -> WorkerInfo {
399        WorkerInfo {
400            id: id.to_string(),
401            language: WorkerLanguage::TypeScript,
402            supported_tasks: tasks.into_iter().map(String::from).collect(),
403            max_concurrency,
404            status: WorkerStatus::Active,
405            active_tasks: 0,
406            registered_at: Utc::now(),
407            last_heartbeat: Utc::now(),
408            tags: vec![],
409        }
410    }
411
412    fn make_tagged_worker(
413        id: &str,
414        tasks: Vec<&str>,
415        max_concurrency: u32,
416        tags: Vec<&str>,
417    ) -> WorkerInfo {
418        WorkerInfo {
419            id: id.to_string(),
420            language: WorkerLanguage::TypeScript,
421            supported_tasks: tasks.into_iter().map(String::from).collect(),
422            max_concurrency,
423            status: WorkerStatus::Active,
424            active_tasks: 0,
425            registered_at: Utc::now(),
426            last_heartbeat: Utc::now(),
427            tags: tags.into_iter().map(String::from).collect(),
428        }
429    }
430
431    #[test]
432    fn test_pool_new_empty() {
433        let pool = WorkerPool::new(15_000);
434        assert_eq!(pool.count(), 0);
435    }
436
437    #[test]
438    fn test_register_single_worker() {
439        let pool = WorkerPool::new(15_000);
440        pool.register(make_worker("w1", vec!["task-a"], 5));
441        assert_eq!(pool.count(), 1);
442    }
443
444    #[test]
445    fn test_register_multiple_workers() {
446        let pool = WorkerPool::new(15_000);
447        pool.register(make_worker("w1", vec!["a"], 5));
448        pool.register(make_worker("w2", vec!["b"], 5));
449        pool.register(make_worker("w3", vec!["c"], 5));
450        assert_eq!(pool.count(), 3);
451    }
452
453    #[test]
454    fn test_register_overwrites_same_id() {
455        let pool = WorkerPool::new(15_000);
456        pool.register(make_worker("w1", vec!["a"], 5));
457        pool.register(make_worker("w1", vec!["b", "c"], 10));
458        assert_eq!(pool.count(), 1);
459    }
460
461    #[test]
462    fn test_deregister_existing_worker() {
463        let pool = WorkerPool::new(15_000);
464        pool.register(make_worker("w1", vec!["a"], 5));
465        pool.deregister("w1");
466        assert_eq!(pool.count(), 0);
467    }
468
469    #[test]
470    fn test_deregister_nonexistent_worker() {
471        let pool = WorkerPool::new(15_000);
472        pool.deregister("ghost");
473        assert_eq!(pool.count(), 0);
474    }
475
476    #[test]
477    fn test_heartbeat_updates_active_tasks() {
478        let pool = WorkerPool::new(15_000);
479        pool.register(make_worker("w1", vec!["a"], 5));
480        pool.heartbeat("w1", 3);
481        let stats = pool.stats();
482        assert_eq!(stats.used_capacity, 3);
483    }
484
485    #[test]
486    fn test_heartbeat_sets_busy_when_at_capacity() {
487        let pool = WorkerPool::new(15_000);
488        pool.register(make_worker("w1", vec!["a"], 2));
489        pool.heartbeat("w1", 2);
490        let stats = pool.stats();
491        assert_eq!(stats.busy, 1);
492        assert_eq!(stats.active, 0);
493    }
494
495    #[test]
496    fn test_heartbeat_sets_active_when_below_capacity() {
497        let pool = WorkerPool::new(15_000);
498        pool.register(make_worker("w1", vec!["a"], 5));
499        pool.heartbeat("w1", 3);
500        let stats = pool.stats();
501        assert_eq!(stats.active, 1);
502    }
503
504    #[test]
505    fn test_heartbeat_nonexistent_worker_is_noop() {
506        let pool = WorkerPool::new(15_000);
507        pool.heartbeat("ghost", 1);
508        assert_eq!(pool.count(), 0);
509    }
510
511    #[test]
512    fn test_select_worker_single_matching() {
513        let pool = WorkerPool::new(15_000);
514        pool.register(make_worker("w1", vec!["build"], 5));
515        assert_eq!(pool.select_worker("build"), Some("w1".to_string()));
516    }
517
518    #[test]
519    fn test_select_worker_returns_none_when_no_matching_type() {
520        let pool = WorkerPool::new(15_000);
521        pool.register(make_worker("w1", vec!["build"], 5));
522        assert_eq!(pool.select_worker("deploy"), None);
523    }
524
525    #[test]
526    fn test_select_worker_returns_none_when_pool_empty() {
527        let pool = WorkerPool::new(15_000);
528        assert_eq!(pool.select_worker("any"), None);
529    }
530
531    #[test]
532    fn test_select_worker_picks_least_loaded() {
533        let pool = WorkerPool::new(15_000);
534        let mut w1 = make_worker("w1", vec!["build"], 5);
535        w1.active_tasks = 4;
536        let mut w2 = make_worker("w2", vec!["build"], 5);
537        w2.active_tasks = 1;
538        pool.register(w1);
539        pool.register(w2);
540        assert_eq!(pool.select_worker("build"), Some("w2".to_string()));
541    }
542
543    #[test]
544    fn test_select_worker_skips_dead_worker() {
545        let pool = WorkerPool::new(15_000);
546        let mut w = make_worker("w1", vec!["build"], 5);
547        w.status = WorkerStatus::Dead;
548        pool.register(w);
549        assert_eq!(pool.select_worker("build"), None);
550    }
551
552    #[test]
553    fn test_select_worker_skips_draining_worker() {
554        let pool = WorkerPool::new(15_000);
555        let mut w = make_worker("w1", vec!["build"], 5);
556        w.status = WorkerStatus::Draining;
557        pool.register(w);
558        assert_eq!(pool.select_worker("build"), None);
559    }
560
561    #[test]
562    fn test_select_worker_skips_at_capacity() {
563        let pool = WorkerPool::new(15_000);
564        let mut w = make_worker("w1", vec!["build"], 2);
565        w.active_tasks = 2;
566        pool.register(w);
567        assert_eq!(pool.select_worker("build"), None);
568    }
569
570    #[test]
571    fn test_select_worker_multiple_task_types() {
572        let pool = WorkerPool::new(15_000);
573        pool.register(make_worker("w1", vec!["a", "b", "c"], 5));
574        assert_eq!(pool.select_worker("b"), Some("w1".to_string()));
575    }
576
577    #[test]
578    fn test_mark_task_dispatched_increments() {
579        let pool = WorkerPool::new(15_000);
580        pool.register(make_worker("w1", vec!["a"], 5));
581        pool.mark_task_dispatched("w1");
582        let stats = pool.stats();
583        assert_eq!(stats.used_capacity, 1);
584    }
585
586    #[test]
587    fn test_mark_task_dispatched_sets_busy_at_capacity() {
588        let pool = WorkerPool::new(15_000);
589        pool.register(make_worker("w1", vec!["a"], 1));
590        pool.mark_task_dispatched("w1");
591        let stats = pool.stats();
592        assert_eq!(stats.busy, 1);
593    }
594
595    #[test]
596    fn test_mark_task_dispatched_nonexistent_is_noop() {
597        let pool = WorkerPool::new(15_000);
598        pool.mark_task_dispatched("ghost");
599    }
600
601    #[test]
602    fn test_mark_task_completed_decrements() {
603        let pool = WorkerPool::new(15_000);
604        let mut w = make_worker("w1", vec!["a"], 5);
605        w.active_tasks = 2;
606        pool.register(w);
607        pool.mark_task_completed("w1");
608        let stats = pool.stats();
609        assert_eq!(stats.used_capacity, 1);
610    }
611
612    #[test]
613    fn test_mark_task_completed_saturating_at_zero() {
614        let pool = WorkerPool::new(15_000);
615        pool.register(make_worker("w1", vec!["a"], 5));
616        pool.mark_task_completed("w1");
617        let stats = pool.stats();
618        assert_eq!(stats.used_capacity, 0);
619    }
620
621    #[test]
622    fn test_mark_task_completed_transitions_busy_to_active() {
623        let pool = WorkerPool::new(15_000);
624        let mut w = make_worker("w1", vec!["a"], 2);
625        w.active_tasks = 2;
626        w.status = WorkerStatus::Busy;
627        pool.register(w);
628        pool.mark_task_completed("w1");
629        let stats = pool.stats();
630        assert_eq!(stats.active, 1);
631        assert_eq!(stats.busy, 0);
632    }
633
634    #[test]
635    fn test_detect_dead_workers_marks_stale() {
636        let pool = WorkerPool::new(100); // 100ms timeout
637        let mut w = make_worker("w1", vec!["a"], 5);
638        w.last_heartbeat = Utc::now() - chrono::Duration::seconds(1);
639        pool.register(w);
640        let dead = pool.detect_dead_workers();
641        assert_eq!(dead, vec!["w1".to_string()]);
642    }
643
644    #[test]
645    fn test_detect_dead_workers_spares_fresh() {
646        let pool = WorkerPool::new(15_000);
647        pool.register(make_worker("w1", vec!["a"], 5));
648        let dead = pool.detect_dead_workers();
649        assert!(dead.is_empty());
650    }
651
652    #[test]
653    fn test_detect_dead_workers_empty_pool() {
654        let pool = WorkerPool::new(15_000);
655        let dead = pool.detect_dead_workers();
656        assert!(dead.is_empty());
657    }
658
659    #[test]
660    fn test_stats_empty_pool() {
661        let pool = WorkerPool::new(15_000);
662        let stats = pool.stats();
663        assert_eq!(stats.total, 0);
664        assert_eq!(stats.active, 0);
665        assert_eq!(stats.total_capacity, 0);
666    }
667
668    #[test]
669    fn test_stats_counts_all_statuses() {
670        let pool = WorkerPool::new(15_000);
671        let mut w1 = make_worker("w1", vec!["a"], 5);
672        w1.status = WorkerStatus::Active;
673        let mut w2 = make_worker("w2", vec!["a"], 5);
674        w2.status = WorkerStatus::Busy;
675        let mut w3 = make_worker("w3", vec!["a"], 5);
676        w3.status = WorkerStatus::Draining;
677        let mut w4 = make_worker("w4", vec!["a"], 5);
678        w4.status = WorkerStatus::Dead;
679        pool.register(w1);
680        pool.register(w2);
681        pool.register(w3);
682        pool.register(w4);
683        let stats = pool.stats();
684        assert_eq!(stats.total, 4);
685        assert_eq!(stats.active, 1);
686        assert_eq!(stats.busy, 1);
687        assert_eq!(stats.draining, 1);
688        assert_eq!(stats.dead, 1);
689    }
690
691    #[test]
692    fn test_stats_capacity_tracking() {
693        let pool = WorkerPool::new(15_000);
694        let mut w = make_worker("w1", vec!["a"], 10);
695        w.active_tasks = 3;
696        pool.register(w);
697        let stats = pool.stats();
698        assert_eq!(stats.total_capacity, 10);
699        assert_eq!(stats.used_capacity, 3);
700    }
701
702    #[test]
703    fn test_active_workers_includes_active_and_busy() {
704        let pool = WorkerPool::new(15_000);
705        let mut w1 = make_worker("w1", vec!["a"], 5);
706        w1.status = WorkerStatus::Active;
707        let mut w2 = make_worker("w2", vec!["a"], 5);
708        w2.status = WorkerStatus::Busy;
709        pool.register(w1);
710        pool.register(w2);
711        assert_eq!(pool.active_workers().len(), 2);
712    }
713
714    #[test]
715    fn test_active_workers_excludes_dead_and_draining() {
716        let pool = WorkerPool::new(15_000);
717        let mut w1 = make_worker("w1", vec!["a"], 5);
718        w1.status = WorkerStatus::Dead;
719        let mut w2 = make_worker("w2", vec!["a"], 5);
720        w2.status = WorkerStatus::Draining;
721        pool.register(w1);
722        pool.register(w2);
723        assert_eq!(pool.active_workers().len(), 0);
724    }
725
726    // =========================================================================
727    // Pool size limits tests
728    // =========================================================================
729
730    #[test]
731    fn test_max_pool_size_rejects_over_limit() {
732        let pool = WorkerPool::with_limits(15_000, Some(2), None, None);
733        pool.register(make_worker("w1", vec!["a"], 5));
734        pool.register(make_worker("w2", vec!["a"], 5));
735        pool.register(make_worker("w3", vec!["a"], 5)); // should be rejected
736        assert_eq!(pool.count(), 2);
737    }
738
739    #[test]
740    fn test_max_pool_size_none_allows_unlimited() {
741        let pool = WorkerPool::with_limits(15_000, None, None, None);
742        for i in 0..100 {
743            pool.register(make_worker(&format!("w{}", i), vec!["a"], 5));
744        }
745        assert_eq!(pool.count(), 100);
746    }
747
748    #[test]
749    fn test_try_register_returns_error_on_full() {
750        let pool = WorkerPool::with_limits(15_000, Some(1), None, None);
751        pool.register(make_worker("w1", vec!["a"], 5));
752        let result = pool.try_register(make_worker("w2", vec!["a"], 5));
753        assert_eq!(result, Err(PoolError::PoolFull { max: 1 }));
754    }
755
756    #[test]
757    fn test_try_register_succeeds_under_limit() {
758        let pool = WorkerPool::with_limits(15_000, Some(3), None, None);
759        let result = pool.try_register(make_worker("w1", vec!["a"], 5));
760        assert!(result.is_ok());
761        assert_eq!(pool.count(), 1);
762    }
763
764    #[test]
765    fn test_min_pool_size_does_not_fire_on_register() {
766        use std::sync::atomic::{AtomicU32, Ordering};
767        let call_count = Arc::new(AtomicU32::new(0));
768        let c = call_count.clone();
769        let pool = WorkerPool::with_limits(
770            15_000,
771            None,
772            Some(3),
773            Some(Arc::new(move |_current| {
774                c.fetch_add(1, Ordering::SeqCst);
775            })),
776        );
777        pool.register(make_worker("w1", vec!["a"], 5));
778        // Callback should NOT fire during registration — only on worker loss
779        assert_eq!(call_count.load(Ordering::SeqCst), 0);
780    }
781
    /// Dropping below `min_pool_size` via `deregister` invokes the callback with
    /// the new worker count.
    #[test]
    fn test_min_pool_size_fires_callback_on_deregister() {
        use std::sync::atomic::{AtomicU32, Ordering};
        // 999 is a sentinel meaning "callback has not run"; the callback overwrites
        // it with the pool size it was given.
        let called_with = Arc::new(AtomicU32::new(999));
        let called_clone = called_with.clone();
        let pool = WorkerPool::with_limits(
            15_000,
            None,
            Some(2),
            Some(Arc::new(move |current| {
                called_clone.store(current, Ordering::SeqCst);
            })),
        );
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.register(make_worker("w2", vec!["a"], 5));
        // Registration never invokes the below-min callback, so the sentinel is
        // still in place. The pool now holds 2 workers, which equals the minimum.
        // Now deregister one
        pool.deregister("w1");
        // Pool has 1, min is 2 => callback fires with 1
        assert_eq!(called_with.load(Ordering::SeqCst), 1);
    }
804
    /// Registering workers never invokes the below-min callback.
    #[test]
    fn test_min_pool_size_no_callback_when_above_min() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let was_called = Arc::new(AtomicBool::new(false));
        let was_called_clone = was_called.clone();
        let pool = WorkerPool::with_limits(
            15_000,
            None,
            Some(1),
            Some(Arc::new(move |_| {
                was_called_clone.store(true, Ordering::SeqCst);
            })),
        );
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.register(make_worker("w2", vec!["a"], 5));
        // Only worker removal checks the minimum; registration does not. The pool
        // (2 workers) is also at or above the minimum (1) throughout, so the
        // callback must not have run.
        assert!(!was_called.load(Ordering::SeqCst));
    }
824
825    // =========================================================================
826    // Worker listing tests
827    // =========================================================================
828
829    #[test]
830    fn test_workers_returns_all() {
831        let pool = WorkerPool::new(15_000);
832        pool.register(make_worker("w1", vec!["a"], 5));
833        pool.register(make_worker("w2", vec!["b"], 3));
834        let workers = pool.workers();
835        assert_eq!(workers.len(), 2);
836        let ids: Vec<&str> = workers.iter().map(|w| w.id.as_str()).collect();
837        assert!(ids.contains(&"w1"));
838        assert!(ids.contains(&"w2"));
839    }
840
841    #[test]
842    fn test_workers_empty_pool() {
843        let pool = WorkerPool::new(15_000);
844        assert!(pool.workers().is_empty());
845    }
846
847    // =========================================================================
848    // Drain worker tests
849    // =========================================================================
850
851    #[test]
852    fn test_drain_worker_sets_status() {
853        let pool = WorkerPool::new(15_000);
854        pool.register(make_worker("w1", vec!["a"], 5));
855        pool.drain_worker("w1").unwrap();
856        let workers = pool.workers();
857        assert_eq!(workers[0].status, WorkerStatus::Draining);
858    }
859
860    #[test]
861    fn test_drain_worker_not_found() {
862        let pool = WorkerPool::new(15_000);
863        let err = pool.drain_worker("ghost").unwrap_err();
864        assert_eq!(
865            err,
866            PoolError::WorkerNotFound {
867                worker_id: "ghost".to_string()
868            }
869        );
870    }
871
872    #[test]
873    fn test_drain_worker_not_selected() {
874        let pool = WorkerPool::new(15_000);
875        pool.register(make_worker("w1", vec!["a"], 5));
876        pool.drain_worker("w1").unwrap();
877        // Draining workers should not be selected
878        assert_eq!(pool.select_worker("a"), None);
879        assert_eq!(pool.select_and_reserve("a"), None);
880    }
881
882    // =========================================================================
883    // Remove worker tests
884    // =========================================================================
885
886    #[test]
887    fn test_remove_worker_removes_from_pool() {
888        let pool = WorkerPool::new(15_000);
889        pool.register(make_worker("w1", vec!["a"], 5));
890        pool.remove_worker("w1").unwrap();
891        assert_eq!(pool.count(), 0);
892    }
893
894    #[test]
895    fn test_remove_worker_not_found() {
896        let pool = WorkerPool::new(15_000);
897        let err = pool.remove_worker("ghost").unwrap_err();
898        assert_eq!(
899            err,
900            PoolError::WorkerNotFound {
901                worker_id: "ghost".to_string()
902            }
903        );
904    }
905
906    #[test]
907    fn test_remove_worker_fires_below_min_callback() {
908        use std::sync::atomic::{AtomicU32, Ordering};
909        let called_with = Arc::new(AtomicU32::new(999));
910        let called_clone = called_with.clone();
911        let pool = WorkerPool::with_limits(
912            15_000,
913            None,
914            Some(2),
915            Some(Arc::new(move |current| {
916                called_clone.store(current, Ordering::SeqCst);
917            })),
918        );
919        pool.register(make_worker("w1", vec!["a"], 5));
920        pool.register(make_worker("w2", vec!["a"], 5));
921        pool.remove_worker("w1").unwrap();
922        assert_eq!(called_with.load(Ordering::SeqCst), 1);
923    }
924
925    // =========================================================================
926    // Worker tags and tag-based dispatch tests
927    // =========================================================================
928
929    #[test]
930    fn test_worker_info_tags_default_empty() {
931        let w = make_worker("w1", vec!["a"], 5);
932        assert!(w.tags.is_empty());
933    }
934
935    #[test]
936    fn test_worker_info_tags_set() {
937        let w = make_tagged_worker("w1", vec!["a"], 5, vec!["gpu", "us-east-1"]);
938        assert_eq!(w.tags, vec!["gpu".to_string(), "us-east-1".to_string()]);
939    }
940
941    #[test]
942    fn test_select_and_reserve_with_tag_finds_matching() {
943        let pool = WorkerPool::new(15_000);
944        pool.register(make_tagged_worker("w1", vec!["build"], 5, vec!["gpu"]));
945        pool.register(make_tagged_worker("w2", vec!["build"], 5, vec!["cpu"]));
946        let selected = pool.select_and_reserve_with_tag("gpu", "build");
947        assert_eq!(selected, Some("w1".to_string()));
948    }
949
950    #[test]
951    fn test_select_and_reserve_with_tag_none_when_no_match() {
952        let pool = WorkerPool::new(15_000);
953        pool.register(make_tagged_worker("w1", vec!["build"], 5, vec!["cpu"]));
954        let selected = pool.select_and_reserve_with_tag("gpu", "build");
955        assert_eq!(selected, None);
956    }
957
958    #[test]
959    fn test_select_and_reserve_with_tag_skips_draining() {
960        let pool = WorkerPool::new(15_000);
961        let mut w = make_tagged_worker("w1", vec!["build"], 5, vec!["gpu"]);
962        w.status = WorkerStatus::Draining;
963        pool.register(w);
964        let selected = pool.select_and_reserve_with_tag("gpu", "build");
965        assert_eq!(selected, None);
966    }
967
968    #[test]
969    fn test_select_and_reserve_with_tag_reserves_capacity() {
970        let pool = WorkerPool::new(15_000);
971        pool.register(make_tagged_worker("w1", vec!["build"], 2, vec!["gpu"]));
972        pool.select_and_reserve_with_tag("gpu", "build");
973        let stats = pool.stats();
974        assert_eq!(stats.used_capacity, 1);
975    }
976
977    #[test]
978    fn test_select_and_reserve_with_tag_requires_task_type_match() {
979        let pool = WorkerPool::new(15_000);
980        pool.register(make_tagged_worker("w1", vec!["build"], 5, vec!["gpu"]));
981        let selected = pool.select_and_reserve_with_tag("gpu", "deploy");
982        assert_eq!(selected, None);
983    }
984
985    // =========================================================================
986    // Reserve specific worker tests
987    // =========================================================================
988
989    #[test]
990    fn test_reserve_specific_worker_success() {
991        let pool = WorkerPool::new(15_000);
992        pool.register(make_worker("w1", vec!["a"], 5));
993        assert!(pool.reserve_specific_worker("w1").is_ok());
994        let stats = pool.stats();
995        assert_eq!(stats.used_capacity, 1);
996    }
997
998    #[test]
999    fn test_reserve_specific_worker_not_found() {
1000        let pool = WorkerPool::new(15_000);
1001        let err = pool.reserve_specific_worker("ghost").unwrap_err();
1002        assert_eq!(
1003            err,
1004            PoolError::WorkerNotFound {
1005                worker_id: "ghost".to_string()
1006            }
1007        );
1008    }
1009
1010    #[test]
1011    fn test_reserve_specific_worker_at_capacity_existing() {
1012        let pool = WorkerPool::new(15_000);
1013        let mut w = make_worker("w1", vec!["a"], 1);
1014        w.active_tasks = 1;
1015        w.status = WorkerStatus::Busy;
1016        pool.register(w);
1017        let err = pool.reserve_specific_worker("w1").unwrap_err();
1018        assert!(matches!(err, PoolError::WorkerAtCapacity { .. }));
1019    }
1020
1021    #[test]
1022    fn test_reserve_specific_worker_draining() {
1023        let pool = WorkerPool::new(15_000);
1024        let mut w = make_worker("w1", vec!["a"], 5);
1025        w.status = WorkerStatus::Draining;
1026        pool.register(w);
1027        let err = pool.reserve_specific_worker("w1").unwrap_err();
1028        assert!(matches!(err, PoolError::WorkerUnavailable { .. }));
1029    }
1030
1031    #[test]
1032    fn test_reserve_specific_worker_at_capacity() {
1033        let pool = WorkerPool::new(15_000);
1034        let mut w = make_worker("w1", vec!["a"], 1);
1035        w.active_tasks = 1;
1036        pool.register(w);
1037        let err = pool.reserve_specific_worker("w1").unwrap_err();
1038        assert!(matches!(err, PoolError::WorkerAtCapacity { .. }));
1039    }
1040
1041    #[test]
1042    fn test_heartbeat_does_not_overwrite_draining() {
1043        let pool = WorkerPool::new(15_000);
1044        pool.register(make_worker("w1", vec!["a"], 5));
1045        pool.drain_worker("w1").unwrap();
1046        pool.heartbeat("w1", 1);
1047        let workers = pool.workers();
1048        assert_eq!(workers[0].status, WorkerStatus::Draining);
1049    }
1050
1051    #[test]
1052    fn test_reregistration_allowed_at_max_capacity() {
1053        let pool = WorkerPool::with_limits(15_000, Some(1), None, None);
1054        pool.register(make_worker("w1", vec!["a"], 5));
1055        // Re-register same worker — should succeed even at max
1056        pool.register(make_worker("w1", vec!["a", "b"], 10));
1057        assert_eq!(pool.count(), 1);
1058    }
1059
1060    #[test]
1061    fn test_register_does_not_fire_below_min_callback() {
1062        use std::sync::atomic::{AtomicU32, Ordering};
1063        let counter = Arc::new(AtomicU32::new(0));
1064        let c = counter.clone();
1065        let pool = WorkerPool::with_limits(
1066            15_000,
1067            None,
1068            Some(5),
1069            Some(Arc::new(move |_| {
1070                c.fetch_add(1, Ordering::SeqCst);
1071            })),
1072        );
1073        pool.register(make_worker("w1", vec!["a"], 5));
1074        // Should NOT fire on register — only on worker loss
1075        assert_eq!(counter.load(Ordering::SeqCst), 0);
1076    }
1077
1078    // =========================================================================
1079    // WorkerInfo serde with tags tests
1080    // =========================================================================
1081
1082    #[test]
1083    fn test_worker_info_serde_with_tags() {
1084        let w = make_tagged_worker("w1", vec!["build"], 5, vec!["gpu", "region-a"]);
1085        let json = serde_json::to_string(&w).unwrap();
1086        let de: WorkerInfo = serde_json::from_str(&json).unwrap();
1087        assert_eq!(de.tags, vec!["gpu".to_string(), "region-a".to_string()]);
1088    }
1089
1090    #[test]
1091    fn test_worker_info_serde_without_tags_defaults_empty() {
1092        // Simulate old JSON without tags field
1093        let json = r#"{
1094            "id": "w1",
1095            "language": "TypeScript",
1096            "supported_tasks": ["a"],
1097            "max_concurrency": 5,
1098            "status": "Active",
1099            "active_tasks": 0,
1100            "registered_at": "2024-01-01T00:00:00Z",
1101            "last_heartbeat": "2024-01-01T00:00:00Z"
1102        }"#;
1103        let de: WorkerInfo = serde_json::from_str(json).unwrap();
1104        assert!(de.tags.is_empty());
1105    }
1106
1107    // =========================================================================
1108    // PoolError display tests
1109    // =========================================================================
1110
1111    #[test]
1112    fn test_pool_error_display_pool_full() {
1113        let err = PoolError::PoolFull { max: 10 };
1114        assert!(err.to_string().contains("maximum capacity"));
1115        assert!(err.to_string().contains("10"));
1116    }
1117
1118    #[test]
1119    fn test_pool_error_display_worker_not_found() {
1120        let err = PoolError::WorkerNotFound {
1121            worker_id: "abc".to_string(),
1122        };
1123        assert!(err.to_string().contains("abc"));
1124    }
1125
1126    // =========================================================================
1127    // with_limits constructor tests
1128    // =========================================================================
1129
1130    #[test]
1131    fn test_with_limits_sets_max() {
1132        let pool = WorkerPool::with_limits(15_000, Some(5), None, None);
1133        pool.register(make_worker("w1", vec!["a"], 1));
1134        pool.register(make_worker("w2", vec!["a"], 1));
1135        pool.register(make_worker("w3", vec!["a"], 1));
1136        pool.register(make_worker("w4", vec!["a"], 1));
1137        pool.register(make_worker("w5", vec!["a"], 1));
1138        pool.register(make_worker("w6", vec!["a"], 1)); // rejected
1139        assert_eq!(pool.count(), 5);
1140    }
1141}