zlayer_proxy/stream/
registry.rs1use dashmap::DashMap;
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15
/// Health state recorded for a single backend address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendHealth {
    /// The backend is considered healthy.
    Healthy,
    /// The backend is considered unhealthy.
    Unhealthy,
    /// No health result is known for the backend.
    Unknown,
}

impl BackendHealth {
    /// Returns `true` when traffic may be sent to a backend in this state.
    ///
    /// `Healthy` and `Unknown` are usable; only `Unhealthy` is not.
    #[must_use]
    pub fn is_usable(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Healthy | Self::Unknown => true,
            Self::Unhealthy => false,
        }
    }
}
34
35#[derive(Clone, Debug)]
37pub struct StreamService {
38 pub name: String,
40 pub backends: Vec<SocketAddr>,
42 health: Arc<RwLock<HashMap<SocketAddr, BackendHealth>>>,
44 rr_index: Arc<AtomicUsize>,
46}
47
48impl StreamService {
49 #[must_use]
51 pub fn new(name: String, backends: Vec<SocketAddr>) -> Self {
52 let health: HashMap<SocketAddr, BackendHealth> = backends
53 .iter()
54 .map(|addr| (*addr, BackendHealth::Unknown))
55 .collect();
56 Self {
57 name,
58 backends,
59 health: Arc::new(RwLock::new(health)),
60 rr_index: Arc::new(AtomicUsize::new(0)),
61 }
62 }
63
64 #[must_use]
69 pub fn select_backend(&self) -> Option<SocketAddr> {
70 if self.backends.is_empty() {
71 return None;
72 }
73
74 let len = self.backends.len();
75 let start = self.rr_index.fetch_add(1, Ordering::Relaxed);
76
77 let health_guard = self.health.try_read();
80
81 if let Ok(health) = health_guard {
82 for i in 0..len {
84 let idx = (start + i) % len;
85 let addr = self.backends[idx];
86 let status = health.get(&addr).copied().unwrap_or(BackendHealth::Unknown);
87 if status.is_usable() {
88 return Some(addr);
89 }
90 }
91 }
92
93 Some(self.backends[start % len])
95 }
96
97 pub fn update_backends(&mut self, backends: Vec<SocketAddr>) {
102 let mut health = self
105 .health
106 .try_write()
107 .unwrap_or_else(|_| {
108 tracing::warn!(service = %self.name, "Health map write contention during backend update");
111 unreachable!("update_backends requires exclusive access")
113 });
114
115 for addr in &backends {
117 health.entry(*addr).or_insert(BackendHealth::Unknown);
118 }
119
120 let backend_set: std::collections::HashSet<SocketAddr> = backends.iter().copied().collect();
122 health.retain(|addr, _| backend_set.contains(addr));
123
124 self.backends = backends;
125 }
126
127 pub async fn set_backend_health(&self, addr: SocketAddr, status: BackendHealth) {
129 let mut health = self.health.write().await;
130 if let Some(h) = health.get_mut(&addr) {
131 *h = status;
132 }
133 }
134
135 pub async fn get_backend_health(&self, addr: SocketAddr) -> BackendHealth {
137 let health = self.health.read().await;
138 health.get(&addr).copied().unwrap_or(BackendHealth::Unknown)
139 }
140
141 #[must_use]
143 pub fn backend_count(&self) -> usize {
144 self.backends.len()
145 }
146
147 pub async fn healthy_count(&self) -> usize {
149 let health = self.health.read().await;
150 self.backends
151 .iter()
152 .filter(|addr| {
153 health
154 .get(addr)
155 .copied()
156 .unwrap_or(BackendHealth::Unknown)
157 .is_usable()
158 })
159 .count()
160 }
161}
162
163#[derive(Default)]
167pub struct StreamRegistry {
168 tcp_services: DashMap<u16, StreamService>,
170 udp_services: DashMap<u16, StreamService>,
172}
173
174impl StreamRegistry {
175 #[must_use]
177 pub fn new() -> Self {
178 Self::default()
179 }
180
181 pub fn register_tcp(&self, port: u16, service: StreamService) {
183 tracing::debug!(
184 port = port,
185 service = %service.name,
186 backends = service.backend_count(),
187 "Registered TCP stream service"
188 );
189 self.tcp_services.insert(port, service);
190 }
191
192 pub fn register_udp(&self, port: u16, service: StreamService) {
194 tracing::debug!(
195 port = port,
196 service = %service.name,
197 backends = service.backend_count(),
198 "Registered UDP stream service"
199 );
200 self.udp_services.insert(port, service);
201 }
202
203 #[must_use]
205 pub fn resolve_tcp(&self, port: u16) -> Option<StreamService> {
206 self.tcp_services.get(&port).map(|s| s.clone())
207 }
208
209 #[must_use]
211 pub fn resolve_udp(&self, port: u16) -> Option<StreamService> {
212 self.udp_services.get(&port).map(|s| s.clone())
213 }
214
215 pub fn update_tcp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
217 if let Some(mut service) = self.tcp_services.get_mut(&port) {
218 tracing::debug!(
219 port = port,
220 service = %service.name,
221 old_count = service.backend_count(),
222 new_count = backends.len(),
223 "Updating TCP backends"
224 );
225 service.update_backends(backends);
226 }
227 }
228
229 pub fn update_udp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
231 if let Some(mut service) = self.udp_services.get_mut(&port) {
232 tracing::debug!(
233 port = port,
234 service = %service.name,
235 old_count = service.backend_count(),
236 new_count = backends.len(),
237 "Updating UDP backends"
238 );
239 service.update_backends(backends);
240 }
241 }
242
243 #[must_use]
245 pub fn unregister_tcp(&self, port: u16) -> Option<StreamService> {
246 self.tcp_services.remove(&port).map(|(_, s)| s)
247 }
248
249 #[must_use]
251 pub fn unregister_udp(&self, port: u16) -> Option<StreamService> {
252 self.udp_services.remove(&port).map(|(_, s)| s)
253 }
254
255 #[must_use]
257 pub fn tcp_count(&self) -> usize {
258 self.tcp_services.len()
259 }
260
261 #[must_use]
263 pub fn udp_count(&self) -> usize {
264 self.udp_services.len()
265 }
266
267 #[must_use]
269 pub fn tcp_ports(&self) -> Vec<u16> {
270 self.tcp_services.iter().map(|e| *e.key()).collect()
271 }
272
273 #[must_use]
275 pub fn udp_ports(&self) -> Vec<u16> {
276 self.udp_services.iter().map(|e| *e.key()).collect()
277 }
278
279 #[must_use]
281 pub fn list_tcp_services(&self) -> Vec<(u16, StreamService)> {
282 self.tcp_services
283 .iter()
284 .map(|e| (*e.key(), e.value().clone()))
285 .collect()
286 }
287
288 #[must_use]
290 pub fn list_udp_services(&self) -> Vec<(u16, StreamService)> {
291 self.udp_services
292 .iter()
293 .map(|e| (*e.key(), e.value().clone()))
294 .collect()
295 }
296
297 #[must_use]
306 pub fn spawn_health_checker(
307 self: &Arc<Self>,
308 interval: Duration,
309 timeout: Duration,
310 ) -> tokio::task::JoinHandle<()> {
311 let registry = Arc::clone(self);
312
313 tokio::spawn(async move {
314 let mut ticker = tokio::time::interval(interval);
315 ticker.tick().await;
317
318 loop {
319 ticker.tick().await;
320
321 for entry in ®istry.tcp_services {
323 let service = entry.value().clone();
324 let backends = service.backends.clone();
325
326 for addr in backends {
327 let svc = service.clone();
328 let probe_timeout = timeout;
329
330 tokio::spawn(async move {
332 let result = tokio::time::timeout(
333 probe_timeout,
334 tokio::net::TcpStream::connect(addr),
335 )
336 .await;
337
338 let health = match result {
339 Ok(Ok(_stream)) => BackendHealth::Healthy,
340 Ok(Err(e)) => {
341 tracing::debug!(
342 service = %svc.name,
343 backend = %addr,
344 error = %e,
345 "TCP health check failed (connect error)"
346 );
347 BackendHealth::Unhealthy
348 }
349 Err(_) => {
350 tracing::debug!(
351 service = %svc.name,
352 backend = %addr,
353 "TCP health check failed (timeout)"
354 );
355 BackendHealth::Unhealthy
356 }
357 };
358
359 svc.set_backend_health(addr, health).await;
360 });
361 }
362 }
363 }
364 })
365 }
366}