1use crate::protocol::{Capabilities, LoadMetrics, WorkerId};
15use chrono::{DateTime, Utc};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::{Arc, RwLock};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct WorkerStatus {
23 pub id: WorkerId,
24 pub address: String,
25 pub capabilities: Capabilities,
26 pub load: Option<LoadMetrics>,
27 pub active_plans: Vec<String>,
28 pub last_heartbeat: DateTime<Utc>,
29 pub connected: bool,
30}
31
32impl WorkerStatus {
33 pub fn has_capacity(&self, max_concurrent: usize) -> bool {
35 self.connected && self.active_plans.len() < max_concurrent
36 }
37
38 pub fn matches_tags(&self, required: &[String]) -> bool {
40 required
41 .iter()
42 .all(|tag| self.capabilities.tags.contains(tag))
43 }
44
45 pub fn is_alive(&self, timeout_secs: i64) -> bool {
47 self.connected && (Utc::now() - self.last_heartbeat).num_seconds() < timeout_secs
48 }
49}
50
51#[derive(Debug, Clone)]
53pub struct WorkerRegistry {
54 workers: Arc<RwLock<HashMap<WorkerId, WorkerStatus>>>,
55 heartbeat_timeout_secs: i64,
56}
57
58impl WorkerRegistry {
59 pub fn new() -> Self {
60 Self {
61 workers: Arc::new(RwLock::new(HashMap::new())),
62 heartbeat_timeout_secs: 30,
63 }
64 }
65
66 pub fn with_heartbeat_timeout(mut self, secs: i64) -> Self {
67 self.heartbeat_timeout_secs = secs;
68 self
69 }
70
71 pub fn register(
73 &self,
74 id: impl Into<String>,
75 address: impl Into<String>,
76 capabilities: Capabilities,
77 ) {
78 let id = id.into();
79 let mut workers = self.workers.write().unwrap();
80 workers.insert(
81 id.clone(),
82 WorkerStatus {
83 id,
84 address: address.into(),
85 capabilities,
86 load: None,
87 active_plans: vec![],
88 last_heartbeat: Utc::now(),
89 connected: true,
90 },
91 );
92 }
93
94 pub fn heartbeat(&self, worker_id: &str, load: LoadMetrics) {
96 let mut workers = self.workers.write().unwrap();
97 if let Some(w) = workers.get_mut(worker_id) {
98 w.load = Some(load);
99 w.last_heartbeat = Utc::now();
100 }
101 }
102
103 pub fn disconnect(&self, worker_id: &str) {
105 let mut workers = self.workers.write().unwrap();
106 if let Some(w) = workers.get_mut(worker_id) {
107 w.connected = false;
108 }
109 }
110
111 pub fn remove(&self, worker_id: &str) {
113 let mut workers = self.workers.write().unwrap();
114 workers.remove(worker_id);
115 }
116
117 pub fn active_workers(&self) -> Vec<WorkerStatus> {
119 let workers = self.workers.read().unwrap();
120 workers
121 .values()
122 .filter(|w| w.is_alive(self.heartbeat_timeout_secs))
123 .cloned()
124 .collect()
125 }
126
127 pub fn get(&self, worker_id: &str) -> Option<WorkerStatus> {
129 let workers = self.workers.read().unwrap();
130 workers.get(worker_id).cloned()
131 }
132
133 pub fn find_workers(&self, tags: &[String], max_concurrent: usize) -> Vec<WorkerStatus> {
135 self.active_workers()
136 .into_iter()
137 .filter(|w| w.matches_tags(tags) && w.has_capacity(max_concurrent))
138 .collect()
139 }
140
141 pub fn total_count(&self) -> usize {
143 self.workers.read().unwrap().len()
144 }
145
146 pub fn active_count(&self) -> usize {
148 self.active_workers().len()
149 }
150
151 pub fn summary(&self) -> String {
153 let workers = self.active_workers();
154 let total_cpus: usize = workers.iter().map(|w| w.capabilities.cpu_cores).sum();
155 let total_gpus: usize = workers.iter().map(|w| w.capabilities.gpus.len()).sum();
156 let total_ram: u64 = workers.iter().map(|w| w.capabilities.ram_bytes).sum();
157 format!(
158 "{} workers ({} CPUs, {} GPUs, {:.1} GB RAM)",
159 workers.len(),
160 total_cpus,
161 total_gpus,
162 total_ram as f64 / (1024.0 * 1024.0 * 1024.0),
163 )
164 }
165
166 pub fn prune_stale(&self) {
168 let mut workers = self.workers.write().unwrap();
169 let timeout = self.heartbeat_timeout_secs;
170 workers.retain(|_, w| w.is_alive(timeout) || w.connected);
171 }
172}
173
174impl Default for WorkerRegistry {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::GpuInfo;

    /// CPU-only capabilities (4 cores, 8 GB RAM, no GPUs) with the given tags.
    fn test_caps(tags: Vec<String>) -> Capabilities {
        Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags,
        }
    }

    /// Capabilities of an 8-core, 32 GB worker with one GPU, tagged
    /// `gpu` and `training`.
    fn gpu_caps() -> Capabilities {
        Capabilities {
            cpu_cores: 8,
            ram_bytes: 32_000_000_000,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80_000_000_000,
            }],
            python_envs: vec![],
            tags: vec!["gpu".into(), "training".into()],
        }
    }

    #[test]
    fn register_and_query() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://host1:8080", test_caps(vec!["cpu".into()]));
        registry.register("w2", "ws://host2:8080", gpu_caps());

        assert_eq!(registry.total_count(), 2);
        assert_eq!(registry.active_count(), 2);

        let w1 = registry.get("w1").unwrap();
        assert_eq!(w1.address, "ws://host1:8080");
        assert!(w1.connected);
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn find_by_tags() {
        let registry = WorkerRegistry::new();
        registry.register("cpu1", "ws://c1:8080", test_caps(vec!["cpu".into()]));
        registry.register("gpu1", "ws://g1:8080", gpu_caps());

        let gpu_workers = registry.find_workers(&["gpu".into()], 10);
        assert_eq!(gpu_workers.len(), 1);
        assert_eq!(gpu_workers[0].id, "gpu1");

        let cpu_workers = registry.find_workers(&["cpu".into()], 10);
        assert_eq!(cpu_workers.len(), 1);
    }

    #[test]
    fn find_requires_every_tag() {
        let registry = WorkerRegistry::new();
        registry.register("cpu1", "ws://c1:8080", test_caps(vec!["cpu".into()]));
        registry.register("gpu1", "ws://g1:8080", gpu_caps());

        // Both tags are advertised by the GPU worker only.
        let both = registry.find_workers(&["gpu".into(), "training".into()], 10);
        assert_eq!(both.len(), 1);
        assert_eq!(both[0].id, "gpu1");

        // No single worker advertises both of these.
        let none = registry.find_workers(&["gpu".into(), "cpu".into()], 10);
        assert!(none.is_empty());
    }

    #[test]
    fn disconnect_and_reconnect() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://host1:8080", test_caps(vec![]));
        assert_eq!(registry.active_count(), 1);

        registry.disconnect("w1");
        assert_eq!(registry.active_count(), 0);
        // The record itself survives a disconnect.
        assert_eq!(registry.total_count(), 1);
        assert!(!registry.get("w1").unwrap().connected);

        // Registering again under the same id brings the worker back.
        registry.register("w1", "ws://host1:8080", test_caps(vec![]));
        assert_eq!(registry.active_count(), 1);
    }

    #[test]
    fn remove_deletes_record() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://host1:8080", test_caps(vec![]));

        registry.remove("w1");
        assert_eq!(registry.total_count(), 0);
        assert!(registry.get("w1").is_none());

        // Removing an unknown id is a no-op.
        registry.remove("w1");
        assert_eq!(registry.total_count(), 0);
    }

    #[test]
    fn expired_heartbeat_is_not_active() {
        // A zero-second timeout means no heartbeat is ever recent enough.
        let registry = WorkerRegistry::new().with_heartbeat_timeout(0);
        registry.register("w1", "ws://host1:8080", test_caps(vec![]));

        assert_eq!(registry.total_count(), 1);
        assert_eq!(registry.active_count(), 0);
        assert!(registry.find_workers(&[], 10).is_empty());
    }

    #[test]
    fn prune_drops_disconnected_stale_workers() {
        let registry = WorkerRegistry::new().with_heartbeat_timeout(0);
        registry.register("keep", "ws://h1:8080", test_caps(vec![]));
        registry.register("gone", "ws://h2:8080", test_caps(vec![]));
        registry.disconnect("gone");

        registry.prune_stale();

        assert_eq!(registry.total_count(), 1);
        assert!(registry.get("keep").is_some());
        assert!(registry.get("gone").is_none());
    }

    #[test]
    fn summary_format() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));
        registry.register("w2", "ws://h2:8080", gpu_caps());

        let s = registry.summary();
        assert!(s.contains("2 workers"));
        // 4 + 8 cores across the two workers.
        assert!(s.contains("12 CPUs"));
        assert!(s.contains("1 GPUs"));
        // 40_000_000_000 bytes / 1024^3 ≈ 37.25, shown with one decimal.
        assert!(s.contains("37.3 GB RAM"));
    }

    #[test]
    fn capacity_check() {
        let registry = WorkerRegistry::new();
        registry.register("w1", "ws://h1:8080", test_caps(vec![]));

        // A limit of zero leaves no room on any worker.
        let workers = registry.find_workers(&[], 0);
        assert!(workers.is_empty());

        // A fresh worker has no active plans, so a limit of one admits it.
        let workers = registry.find_workers(&[], 1);
        assert_eq!(workers.len(), 1);
    }
}