Skip to main content

rust_pipe/worker/
mod.rs

1/// Worker pool management and least-loaded selection.
2use chrono::{DateTime, Utc};
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5
6use crate::transport::WorkerLanguage;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9/// Information about a connected worker.
10pub struct WorkerInfo {
11    pub id: String,
12    pub language: WorkerLanguage,
13    pub supported_tasks: Vec<String>,
14    pub max_concurrency: u32,
15    pub status: WorkerStatus,
16    pub active_tasks: u32,
17    pub registered_at: DateTime<Utc>,
18    pub last_heartbeat: DateTime<Utc>,
19}
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
22/// Current status of a worker in the pool.
23pub enum WorkerStatus {
24    Active,
25    Busy,
26    Draining,
27    Dead,
28}
29
30/// Pool of connected workers with capacity-aware task routing.
31pub struct WorkerPool {
32    workers: DashMap<String, WorkerInfo>,
33    heartbeat_timeout_ms: u64,
34}
35
36impl WorkerPool {
37    pub fn new(heartbeat_timeout_ms: u64) -> Self {
38        Self {
39            workers: DashMap::new(),
40            heartbeat_timeout_ms,
41        }
42    }
43
44    pub fn register(&self, info: WorkerInfo) {
45        tracing::info!(
46            worker_id = %info.id,
47            language = ?info.language,
48            tasks = ?info.supported_tasks,
49            "Registering worker"
50        );
51        self.workers.insert(info.id.clone(), info);
52    }
53
54    pub fn deregister(&self, worker_id: &str) {
55        self.workers.remove(worker_id);
56    }
57
58    pub fn heartbeat(&self, worker_id: &str, active_tasks: u32) {
59        if let Some(mut worker) = self.workers.get_mut(worker_id) {
60            worker.last_heartbeat = Utc::now();
61            worker.active_tasks = active_tasks;
62            worker.status = if active_tasks >= worker.max_concurrency {
63                WorkerStatus::Busy
64            } else {
65                WorkerStatus::Active
66            };
67        }
68    }
69
70    /// Atomically selects a worker and reserves capacity.
71    /// Returns the worker ID if one is available, or None.
72    /// This avoids the TOCTOU race between select and dispatch.
73    pub fn select_and_reserve(&self, task_type: &str) -> Option<String> {
74        let mut best_id: Option<String> = None;
75        let mut best_capacity: u32 = 0;
76
77        for entry in self.workers.iter() {
78            let worker = entry.value();
79            if worker.status == WorkerStatus::Dead || worker.status == WorkerStatus::Draining {
80                continue;
81            }
82            if !worker.supported_tasks.iter().any(|t| t == task_type) {
83                continue;
84            }
85            if worker.active_tasks >= worker.max_concurrency {
86                continue;
87            }
88
89            let available = worker.max_concurrency - worker.active_tasks;
90            if best_id.is_none() || available > best_capacity {
91                best_id = Some(worker.id.clone());
92                best_capacity = available;
93            }
94        }
95
96        let worker_id = best_id?;
97
98        // Atomically increment under the entry lock
99        if let Some(mut worker) = self.workers.get_mut(&worker_id) {
100            if worker.active_tasks >= worker.max_concurrency {
101                return None;
102            }
103            worker.active_tasks += 1;
104            if worker.active_tasks >= worker.max_concurrency {
105                worker.status = WorkerStatus::Busy;
106            }
107            Some(worker_id)
108        } else {
109            None
110        }
111    }
112
113    /// Selects the least-loaded worker for a task type WITHOUT modifying state.
114    /// Use `select_and_reserve` for dispatch to avoid TOCTOU races.
115    pub fn select_worker(&self, task_type: &str) -> Option<String> {
116        let mut best_id: Option<String> = None;
117        let mut best_capacity: u32 = 0;
118
119        for entry in self.workers.iter() {
120            let worker = entry.value();
121            if worker.status == WorkerStatus::Dead || worker.status == WorkerStatus::Draining {
122                continue;
123            }
124            if !worker.supported_tasks.iter().any(|t| t == task_type) {
125                continue;
126            }
127            if worker.active_tasks >= worker.max_concurrency {
128                continue;
129            }
130
131            let available = worker.max_concurrency - worker.active_tasks;
132            if best_id.is_none() || available > best_capacity {
133                best_id = Some(worker.id.clone());
134                best_capacity = available;
135            }
136        }
137
138        best_id
139    }
140
141    pub fn mark_task_dispatched(&self, worker_id: &str) {
142        if let Some(mut worker) = self.workers.get_mut(worker_id) {
143            worker.active_tasks += 1;
144            if worker.active_tasks >= worker.max_concurrency {
145                worker.status = WorkerStatus::Busy;
146            }
147        }
148    }
149
150    pub fn mark_task_completed(&self, worker_id: &str) {
151        if let Some(mut worker) = self.workers.get_mut(worker_id) {
152            worker.active_tasks = worker.active_tasks.saturating_sub(1);
153            if worker.active_tasks < worker.max_concurrency && worker.status == WorkerStatus::Busy {
154                worker.status = WorkerStatus::Active;
155            }
156        }
157    }
158
159    pub fn detect_dead_workers(&self) -> Vec<String> {
160        let now = Utc::now();
161        let mut dead = Vec::new();
162
163        for mut entry in self.workers.iter_mut() {
164            let elapsed_ms = (now - entry.last_heartbeat).num_milliseconds().max(0) as u64;
165            if elapsed_ms > self.heartbeat_timeout_ms {
166                entry.status = WorkerStatus::Dead;
167                dead.push(entry.id.clone());
168            }
169        }
170
171        dead
172    }
173
174    pub fn active_workers(&self) -> Vec<WorkerInfo> {
175        self.workers
176            .iter()
177            .filter(|w| w.status == WorkerStatus::Active || w.status == WorkerStatus::Busy)
178            .map(|w| w.value().clone())
179            .collect()
180    }
181
182    pub fn count(&self) -> usize {
183        self.workers.len()
184    }
185
186    pub fn stats(&self) -> PoolStats {
187        let mut stats = PoolStats::default();
188        for entry in self.workers.iter() {
189            stats.total += 1;
190            match entry.status {
191                WorkerStatus::Active => stats.active += 1,
192                WorkerStatus::Busy => stats.busy += 1,
193                WorkerStatus::Draining => stats.draining += 1,
194                WorkerStatus::Dead => stats.dead += 1,
195            }
196            stats.total_capacity += entry.max_concurrency;
197            stats.used_capacity += entry.active_tasks;
198        }
199        stats
200    }
201}
202
203#[derive(Debug, Default, Serialize)]
204/// Aggregate statistics about the worker pool.
205pub struct PoolStats {
206    pub total: u32,
207    pub active: u32,
208    pub busy: u32,
209    pub draining: u32,
210    pub dead: u32,
211    pub total_capacity: u32,
212    pub used_capacity: u32,
213}
214
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (`Utc`, `WorkerLanguage`, ...).
    use super::*;

    /// Builds an `Active`, idle worker with a fresh heartbeat.
    fn make_worker(id: &str, tasks: Vec<&str>, max_concurrency: u32) -> WorkerInfo {
        WorkerInfo {
            id: id.to_string(),
            language: WorkerLanguage::TypeScript,
            supported_tasks: tasks.into_iter().map(String::from).collect(),
            max_concurrency,
            status: WorkerStatus::Active,
            active_tasks: 0,
            registered_at: Utc::now(),
            last_heartbeat: Utc::now(),
        }
    }

    #[test]
    fn test_pool_new_empty() {
        let pool = WorkerPool::new(15_000);
        assert_eq!(pool.count(), 0);
    }

    #[test]
    fn test_register_single_worker() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["task-a"], 5));
        assert_eq!(pool.count(), 1);
    }

    #[test]
    fn test_register_multiple_workers() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.register(make_worker("w2", vec!["b"], 5));
        pool.register(make_worker("w3", vec!["c"], 5));
        assert_eq!(pool.count(), 3);
    }

    #[test]
    fn test_register_overwrites_same_id() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.register(make_worker("w1", vec!["b", "c"], 10));
        assert_eq!(pool.count(), 1);
    }

    #[test]
    fn test_deregister_existing_worker() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.deregister("w1");
        assert_eq!(pool.count(), 0);
    }

    #[test]
    fn test_deregister_nonexistent_worker() {
        let pool = WorkerPool::new(15_000);
        pool.deregister("ghost");
        assert_eq!(pool.count(), 0);
    }

    #[test]
    fn test_heartbeat_updates_active_tasks() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.heartbeat("w1", 3);
        let stats = pool.stats();
        assert_eq!(stats.used_capacity, 3);
    }

    #[test]
    fn test_heartbeat_sets_busy_when_at_capacity() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 2));
        pool.heartbeat("w1", 2);
        let stats = pool.stats();
        assert_eq!(stats.busy, 1);
        assert_eq!(stats.active, 0);
    }

    #[test]
    fn test_heartbeat_sets_active_when_below_capacity() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.heartbeat("w1", 3);
        let stats = pool.stats();
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_heartbeat_nonexistent_worker_is_noop() {
        let pool = WorkerPool::new(15_000);
        pool.heartbeat("ghost", 1);
        assert_eq!(pool.count(), 0);
    }

    #[test]
    fn test_select_worker_single_matching() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 5));
        assert_eq!(pool.select_worker("build"), Some("w1".to_string()));
    }

    #[test]
    fn test_select_worker_returns_none_when_no_matching_type() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 5));
        assert_eq!(pool.select_worker("deploy"), None);
    }

    #[test]
    fn test_select_worker_returns_none_when_pool_empty() {
        let pool = WorkerPool::new(15_000);
        assert_eq!(pool.select_worker("any"), None);
    }

    #[test]
    fn test_select_worker_picks_least_loaded() {
        let pool = WorkerPool::new(15_000);
        let mut w1 = make_worker("w1", vec!["build"], 5);
        w1.active_tasks = 4;
        let mut w2 = make_worker("w2", vec!["build"], 5);
        w2.active_tasks = 1;
        pool.register(w1);
        pool.register(w2);
        assert_eq!(pool.select_worker("build"), Some("w2".to_string()));
    }

    #[test]
    fn test_select_worker_skips_dead_worker() {
        let pool = WorkerPool::new(15_000);
        let mut w = make_worker("w1", vec!["build"], 5);
        w.status = WorkerStatus::Dead;
        pool.register(w);
        assert_eq!(pool.select_worker("build"), None);
    }

    #[test]
    fn test_select_worker_skips_draining_worker() {
        let pool = WorkerPool::new(15_000);
        let mut w = make_worker("w1", vec!["build"], 5);
        w.status = WorkerStatus::Draining;
        pool.register(w);
        assert_eq!(pool.select_worker("build"), None);
    }

    #[test]
    fn test_select_worker_skips_at_capacity() {
        let pool = WorkerPool::new(15_000);
        let mut w = make_worker("w1", vec!["build"], 2);
        w.active_tasks = 2;
        pool.register(w);
        assert_eq!(pool.select_worker("build"), None);
    }

    #[test]
    fn test_select_worker_multiple_task_types() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a", "b", "c"], 5));
        assert_eq!(pool.select_worker("b"), Some("w1".to_string()));
    }

    #[test]
    fn test_select_and_reserve_increments_active_tasks() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 5));
        assert_eq!(pool.select_and_reserve("build"), Some("w1".to_string()));
        assert_eq!(pool.stats().used_capacity, 1);
    }

    #[test]
    fn test_select_and_reserve_sets_busy_at_capacity() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 1));
        assert_eq!(pool.select_and_reserve("build"), Some("w1".to_string()));
        let stats = pool.stats();
        assert_eq!(stats.busy, 1);
        assert_eq!(stats.active, 0);
    }

    #[test]
    fn test_select_and_reserve_returns_none_when_exhausted() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 1));
        assert!(pool.select_and_reserve("build").is_some());
        assert_eq!(pool.select_and_reserve("build"), None);
        // A failed reservation must not consume capacity.
        assert_eq!(pool.stats().used_capacity, 1);
    }

    #[test]
    fn test_select_and_reserve_spreads_across_workers() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 1));
        pool.register(make_worker("w2", vec!["build"], 1));
        let first = pool.select_and_reserve("build").expect("first reservation");
        let second = pool.select_and_reserve("build").expect("second reservation");
        assert_ne!(first, second);
        assert_eq!(pool.select_and_reserve("build"), None);
    }

    #[test]
    fn test_select_and_reserve_picks_least_loaded() {
        let pool = WorkerPool::new(15_000);
        let mut w1 = make_worker("w1", vec!["build"], 5);
        w1.active_tasks = 4;
        let mut w2 = make_worker("w2", vec!["build"], 5);
        w2.active_tasks = 1;
        pool.register(w1);
        pool.register(w2);
        assert_eq!(pool.select_and_reserve("build"), Some("w2".to_string()));
    }

    #[test]
    fn test_select_and_reserve_skips_draining_and_unsupported() {
        let pool = WorkerPool::new(15_000);
        let mut draining = make_worker("w1", vec!["build"], 5);
        draining.status = WorkerStatus::Draining;
        pool.register(draining);
        pool.register(make_worker("w2", vec!["deploy"], 5));
        assert_eq!(pool.select_and_reserve("build"), None);
        assert_eq!(pool.stats().used_capacity, 0);
    }

    #[test]
    fn test_mark_task_completed_frees_capacity_for_reserve() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["build"], 1));
        assert!(pool.select_and_reserve("build").is_some());
        assert_eq!(pool.select_and_reserve("build"), None);
        pool.mark_task_completed("w1");
        assert_eq!(pool.select_and_reserve("build"), Some("w1".to_string()));
    }

    #[test]
    fn test_mark_task_dispatched_increments() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.mark_task_dispatched("w1");
        let stats = pool.stats();
        assert_eq!(stats.used_capacity, 1);
    }

    #[test]
    fn test_mark_task_dispatched_sets_busy_at_capacity() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 1));
        pool.mark_task_dispatched("w1");
        let stats = pool.stats();
        assert_eq!(stats.busy, 1);
    }

    #[test]
    fn test_mark_task_dispatched_nonexistent_is_noop() {
        let pool = WorkerPool::new(15_000);
        pool.mark_task_dispatched("ghost");
    }

    #[test]
    fn test_mark_task_completed_decrements() {
        let pool = WorkerPool::new(15_000);
        let mut w = make_worker("w1", vec!["a"], 5);
        w.active_tasks = 2;
        pool.register(w);
        pool.mark_task_completed("w1");
        let stats = pool.stats();
        assert_eq!(stats.used_capacity, 1);
    }

    #[test]
    fn test_mark_task_completed_saturating_at_zero() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        pool.mark_task_completed("w1");
        let stats = pool.stats();
        assert_eq!(stats.used_capacity, 0);
    }

    #[test]
    fn test_mark_task_completed_transitions_busy_to_active() {
        let pool = WorkerPool::new(15_000);
        let mut w = make_worker("w1", vec!["a"], 2);
        w.active_tasks = 2;
        w.status = WorkerStatus::Busy;
        pool.register(w);
        pool.mark_task_completed("w1");
        let stats = pool.stats();
        assert_eq!(stats.active, 1);
        assert_eq!(stats.busy, 0);
    }

    #[test]
    fn test_detect_dead_workers_marks_stale() {
        let pool = WorkerPool::new(100); // 100ms timeout
        let mut w = make_worker("w1", vec!["a"], 5);
        w.last_heartbeat = Utc::now() - chrono::Duration::seconds(1);
        pool.register(w);
        let dead = pool.detect_dead_workers();
        assert_eq!(dead, vec!["w1".to_string()]);
    }

    #[test]
    fn test_detect_dead_workers_excludes_from_selection() {
        let pool = WorkerPool::new(100); // 100ms timeout
        let mut w = make_worker("w1", vec!["a"], 5);
        w.last_heartbeat = Utc::now() - chrono::Duration::seconds(1);
        pool.register(w);
        pool.detect_dead_workers();
        assert_eq!(pool.select_worker("a"), None);
        assert_eq!(pool.select_and_reserve("a"), None);
    }

    #[test]
    fn test_detect_dead_workers_spares_fresh() {
        let pool = WorkerPool::new(15_000);
        pool.register(make_worker("w1", vec!["a"], 5));
        let dead = pool.detect_dead_workers();
        assert!(dead.is_empty());
    }

    #[test]
    fn test_detect_dead_workers_empty_pool() {
        let pool = WorkerPool::new(15_000);
        let dead = pool.detect_dead_workers();
        assert!(dead.is_empty());
    }

    #[test]
    fn test_stats_empty_pool() {
        let pool = WorkerPool::new(15_000);
        let stats = pool.stats();
        assert_eq!(stats.total, 0);
        assert_eq!(stats.active, 0);
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_stats_counts_all_statuses() {
        let pool = WorkerPool::new(15_000);
        let mut w1 = make_worker("w1", vec!["a"], 5);
        w1.status = WorkerStatus::Active;
        let mut w2 = make_worker("w2", vec!["a"], 5);
        w2.status = WorkerStatus::Busy;
        let mut w3 = make_worker("w3", vec!["a"], 5);
        w3.status = WorkerStatus::Draining;
        let mut w4 = make_worker("w4", vec!["a"], 5);
        w4.status = WorkerStatus::Dead;
        pool.register(w1);
        pool.register(w2);
        pool.register(w3);
        pool.register(w4);
        let stats = pool.stats();
        assert_eq!(stats.total, 4);
        assert_eq!(stats.active, 1);
        assert_eq!(stats.busy, 1);
        assert_eq!(stats.draining, 1);
        assert_eq!(stats.dead, 1);
    }

    #[test]
    fn test_stats_capacity_tracking() {
        let pool = WorkerPool::new(15_000);
        let mut w = make_worker("w1", vec!["a"], 10);
        w.active_tasks = 3;
        pool.register(w);
        let stats = pool.stats();
        assert_eq!(stats.total_capacity, 10);
        assert_eq!(stats.used_capacity, 3);
    }

    #[test]
    fn test_active_workers_includes_active_and_busy() {
        let pool = WorkerPool::new(15_000);
        let mut w1 = make_worker("w1", vec!["a"], 5);
        w1.status = WorkerStatus::Active;
        let mut w2 = make_worker("w2", vec!["a"], 5);
        w2.status = WorkerStatus::Busy;
        pool.register(w1);
        pool.register(w2);
        assert_eq!(pool.active_workers().len(), 2);
    }

    #[test]
    fn test_active_workers_excludes_dead_and_draining() {
        let pool = WorkerPool::new(15_000);
        let mut w1 = make_worker("w1", vec!["a"], 5);
        w1.status = WorkerStatus::Dead;
        let mut w2 = make_worker("w2", vec!["a"], 5);
        w2.status = WorkerStatus::Draining;
        pool.register(w1);
        pool.register(w2);
        assert_eq!(pool.active_workers().len(), 0);
    }
}