// somatize_worker/coordinator.rs

//! Coordinator — lightweight gateway that manages worker registration,
//! routing, and health monitoring.
//!
//! Can run as:
//! - **Standalone binary**: `soma-coordinator --token sk-xxx --port 9090`
//! - **Embedded**: `Coordinator::new().start_local()` for development
//!
//! The coordinator does NOT execute plans. It:
//! 1. Accepts worker registrations (with capabilities + heartbeats)
//! 2. Authenticates connections via bearer token
//! 3. Routes client plan submissions to appropriate workers
//! 4. Forwards worker events back to the client
//!
//! This module provides the [`WorkerRegistry`], which tracks registered
//! workers, their heartbeats, and their capacity for routing decisions.
13
14use crate::protocol::{Capabilities, LoadMetrics, WorkerId};
15use chrono::{DateTime, Utc};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::{Arc, RwLock};
19
/// Status of a registered worker.
///
/// One entry exists per worker id inside a [`WorkerRegistry`]; the registry
/// hands out clones of these entries, never references into its map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerStatus {
    /// Worker identifier; also the key under which the registry stores this entry.
    pub id: WorkerId,
    /// Address supplied at registration (the tests use `ws://host:port` URLs).
    pub address: String,
    /// Capabilities advertised at registration; `tags` drives routing via
    /// [`WorkerStatus::matches_tags`].
    pub capabilities: Capabilities,
    /// Load metrics from the most recent heartbeat; `None` after (re-)registration
    /// until the next heartbeat arrives.
    pub load: Option<LoadMetrics>,
    /// Plans currently assigned to this worker; only its length is consulted
    /// (by `has_capacity`). Presumably plan ids — TODO confirm with callers.
    pub active_plans: Vec<String>,
    /// Time of registration or of the latest heartbeat, whichever is newer.
    pub last_heartbeat: DateTime<Utc>,
    /// Set to `false` by `WorkerRegistry::disconnect`; only re-registering sets
    /// it back to `true`.
    pub connected: bool,
}
31
32impl WorkerStatus {
33    /// Whether the worker has capacity for more work.
34    pub fn has_capacity(&self, max_concurrent: usize) -> bool {
35        self.connected && self.active_plans.len() < max_concurrent
36    }
37
38    /// Whether the worker matches a set of required tags.
39    pub fn matches_tags(&self, required: &[String]) -> bool {
40        required
41            .iter()
42            .all(|tag| self.capabilities.tags.contains(tag))
43    }
44
45    /// Whether the worker is considered alive (heartbeat within timeout).
46    pub fn is_alive(&self, timeout_secs: i64) -> bool {
47        self.connected && (Utc::now() - self.last_heartbeat).num_seconds() < timeout_secs
48    }
49}
50
/// The worker registry — tracks all known workers and their status.
///
/// Cloning a registry is cheap and yields a handle to the same underlying map
/// (it lives behind an `Arc<RwLock<..>>`), so workers registered through one
/// clone are visible through every other clone. The heartbeat timeout, by
/// contrast, is a plain value copied into each clone.
#[derive(Debug, Clone)]
pub struct WorkerRegistry {
    /// Known workers keyed by id, including disconnected ones until they are
    /// removed or pruned.
    workers: Arc<RwLock<HashMap<WorkerId, WorkerStatus>>>,
    /// Seconds without a heartbeat after which a worker stops counting as
    /// alive (30 when built with `new`; see `with_heartbeat_timeout`).
    heartbeat_timeout_secs: i64,
}
57
58impl WorkerRegistry {
59    pub fn new() -> Self {
60        Self {
61            workers: Arc::new(RwLock::new(HashMap::new())),
62            heartbeat_timeout_secs: 30,
63        }
64    }
65
66    pub fn with_heartbeat_timeout(mut self, secs: i64) -> Self {
67        self.heartbeat_timeout_secs = secs;
68        self
69    }
70
71    /// Register a new worker or update an existing one.
72    pub fn register(
73        &self,
74        id: impl Into<String>,
75        address: impl Into<String>,
76        capabilities: Capabilities,
77    ) {
78        let id = id.into();
79        let mut workers = self.workers.write().unwrap();
80        workers.insert(
81            id.clone(),
82            WorkerStatus {
83                id,
84                address: address.into(),
85                capabilities,
86                load: None,
87                active_plans: vec![],
88                last_heartbeat: Utc::now(),
89                connected: true,
90            },
91        );
92    }
93
94    /// Update a worker's heartbeat and load metrics.
95    pub fn heartbeat(&self, worker_id: &str, load: LoadMetrics) {
96        let mut workers = self.workers.write().unwrap();
97        if let Some(w) = workers.get_mut(worker_id) {
98            w.load = Some(load);
99            w.last_heartbeat = Utc::now();
100        }
101    }
102
103    /// Mark a worker as disconnected.
104    pub fn disconnect(&self, worker_id: &str) {
105        let mut workers = self.workers.write().unwrap();
106        if let Some(w) = workers.get_mut(worker_id) {
107            w.connected = false;
108        }
109    }
110
111    /// Remove a worker entirely.
112    pub fn remove(&self, worker_id: &str) {
113        let mut workers = self.workers.write().unwrap();
114        workers.remove(worker_id);
115    }
116
117    /// Get all alive, connected workers.
118    pub fn active_workers(&self) -> Vec<WorkerStatus> {
119        let workers = self.workers.read().unwrap();
120        workers
121            .values()
122            .filter(|w| w.is_alive(self.heartbeat_timeout_secs))
123            .cloned()
124            .collect()
125    }
126
127    /// Get a specific worker by ID.
128    pub fn get(&self, worker_id: &str) -> Option<WorkerStatus> {
129        let workers = self.workers.read().unwrap();
130        workers.get(worker_id).cloned()
131    }
132
133    /// Find workers matching required tags with available capacity.
134    pub fn find_workers(&self, tags: &[String], max_concurrent: usize) -> Vec<WorkerStatus> {
135        self.active_workers()
136            .into_iter()
137            .filter(|w| w.matches_tags(tags) && w.has_capacity(max_concurrent))
138            .collect()
139    }
140
141    /// Total number of registered workers (including disconnected).
142    pub fn total_count(&self) -> usize {
143        self.workers.read().unwrap().len()
144    }
145
146    /// Number of alive, connected workers.
147    pub fn active_count(&self) -> usize {
148        self.active_workers().len()
149    }
150
151    /// Human-readable summary.
152    pub fn summary(&self) -> String {
153        let workers = self.active_workers();
154        let total_cpus: usize = workers.iter().map(|w| w.capabilities.cpu_cores).sum();
155        let total_gpus: usize = workers.iter().map(|w| w.capabilities.gpus.len()).sum();
156        let total_ram: u64 = workers.iter().map(|w| w.capabilities.ram_bytes).sum();
157        format!(
158            "{} workers ({} CPUs, {} GPUs, {:.1} GB RAM)",
159            workers.len(),
160            total_cpus,
161            total_gpus,
162            total_ram as f64 / (1024.0 * 1024.0 * 1024.0),
163        )
164    }
165
166    /// Prune workers that haven't sent a heartbeat within the timeout.
167    pub fn prune_stale(&self) {
168        let mut workers = self.workers.write().unwrap();
169        let timeout = self.heartbeat_timeout_secs;
170        workers.retain(|_, w| w.is_alive(timeout) || w.connected);
171    }
172}
173
174impl Default for WorkerRegistry {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::GpuInfo;

    /// CPU-only capabilities (4 cores, 8e9 bytes RAM) with the given routing tags.
    fn test_caps(tags: Vec<String>) -> Capabilities {
        Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags,
        }
    }

    /// GPU worker capabilities: 8 cores, 32e9 bytes RAM, one A100, tagged
    /// `gpu` and `training`.
    fn gpu_caps() -> Capabilities {
        Capabilities {
            cpu_cores: 8,
            ram_bytes: 32_000_000_000,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80_000_000_000,
            }],
            python_envs: vec![],
            tags: vec!["gpu".into(), "training".into()],
        }
    }

    #[test]
    fn register_and_query() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://host1:8080", test_caps(vec!["cpu".into()]));
        registry.register("w2", "ws://host2:8080", gpu_caps());

        assert_eq!(registry.total_count(), 2);
        assert_eq!(registry.active_count(), 2);

        let w1 = registry.get("w1").unwrap();
        assert_eq!(w1.address, "ws://host1:8080");
        assert!(w1.connected);
    }

    #[test]
    fn find_by_tags() {
        let registry = WorkerRegistry::new();
        registry.register("cpu1", "ws://c1:8080", test_caps(vec!["cpu".into()]));
        registry.register("gpu1", "ws://g1:8080", gpu_caps());

        let gpu_workers = registry.find_workers(&["gpu".into()], 10);
        assert_eq!(gpu_workers.len(), 1);
        assert_eq!(gpu_workers[0].id, "gpu1");

        let cpu_workers = registry.find_workers(&["cpu".into()], 10);
        assert_eq!(cpu_workers.len(), 1);
    }

    #[test]
    fn find_requires_every_tag() {
        let registry = WorkerRegistry::new();
        registry.register("gpu1", "ws://g1:8080", gpu_caps());

        assert_eq!(
            registry
                .find_workers(&["gpu".into(), "training".into()], 10)
                .len(),
            1
        );
        assert!(registry
            .find_workers(&["gpu".into(), "cpu".into()], 10)
            .is_empty());
    }

    #[test]
    fn disconnect_and_reconnect() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://host1:8080", test_caps(vec![]));
        assert_eq!(registry.active_count(), 1);

        registry.disconnect("w1");
        assert_eq!(registry.active_count(), 0);

        // Re-register = reconnect
        registry.register("w1", "ws://host1:8080", test_caps(vec![]));
        assert_eq!(registry.active_count(), 1);
    }

    #[test]
    fn summary_format() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));
        registry.register("w2", "ws://h2:8080", gpu_caps());

        let s = registry.summary();
        assert!(s.contains("2 workers"));
        assert!(s.contains("12 CPUs")); // 4 + 8
        assert!(s.contains("1 GPUs"));
    }

    #[test]
    fn capacity_check() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));

        // With max_concurrent=0, no one has capacity
        let workers = registry.find_workers(&[], 0);
        assert!(workers.is_empty());

        // With max_concurrent=1, worker with 0 active plans has capacity
        let workers = registry.find_workers(&[], 1);
        assert_eq!(workers.len(), 1);
    }

    #[test]
    fn zero_timeout_treats_every_worker_as_stale() {
        // Liveness needs elapsed whole seconds < timeout, which never holds for 0.
        let registry = WorkerRegistry::new().with_heartbeat_timeout(0);
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));

        assert_eq!(registry.total_count(), 1);
        assert_eq!(registry.active_count(), 0);
        assert!(registry.active_workers().is_empty());
        assert!(registry.find_workers(&[], 10).is_empty());
    }

    #[test]
    fn prune_keeps_fresh_workers() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));

        registry.prune_stale();
        assert_eq!(registry.total_count(), 1);
    }

    #[test]
    fn prune_drops_stale_disconnected_workers() {
        let registry = WorkerRegistry::new().with_heartbeat_timeout(0);
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));
        registry.disconnect("w1");

        registry.prune_stale();
        assert_eq!(registry.total_count(), 0);
        assert!(registry.get("w1").is_none());
    }

    #[test]
    fn remove_forgets_worker() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));

        registry.remove("w1");
        assert!(registry.get("w1").is_none());
        assert_eq!(registry.total_count(), 0);

        // Removing or disconnecting an unknown id is a silent no-op.
        registry.remove("missing");
        registry.disconnect("missing");
        assert_eq!(registry.total_count(), 0);
    }

    #[test]
    fn clones_share_the_same_registry() {
        let registry = WorkerRegistry::new();
        let handle = registry.clone();
        handle.register("w1", "ws://h1:8080", test_caps(vec![]));

        assert_eq!(registry.total_count(), 1);
        assert!(registry.get("w1").is_some());
    }
}