Skip to main content

zlayer_proxy/stream/
registry.rs

1//! Stream service registry for L4 routing
2//!
3//! Maps listen ports to backend services for TCP and UDP proxying.
4//! Includes health-aware backend selection: unhealthy backends are
5//! skipped during round-robin selection, with a fallback to any
6//! backend if all are marked unhealthy.
7
8use dashmap::DashMap;
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15
/// Health state of a stream backend.
///
/// Only [`BackendHealth::Unhealthy`] marks a backend as unusable; a backend
/// whose health has not been determined yet is treated as usable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendHealth {
    /// Backend is reachable and accepting connections
    Healthy,
    /// Backend failed the last health probe
    Unhealthy,
    /// Health has not yet been determined (treated as healthy)
    Unknown,
}

impl BackendHealth {
    /// Returns `true` if the backend may be handed out for new connections.
    ///
    /// `Healthy` and `Unknown` are usable; `Unhealthy` is not.
    #[must_use]
    pub fn is_usable(self) -> bool {
        match self {
            BackendHealth::Unhealthy => false,
            BackendHealth::Healthy | BackendHealth::Unknown => true,
        }
    }
}
34
35/// A resolved stream service with backend addresses and health state
36#[derive(Clone, Debug)]
37pub struct StreamService {
38    /// Service name (for logging/metrics)
39    pub name: String,
40    /// Backend addresses for load balancing
41    pub backends: Vec<SocketAddr>,
42    /// Per-backend health state
43    health: Arc<RwLock<HashMap<SocketAddr, BackendHealth>>>,
44    /// Round-robin index for backend selection
45    rr_index: Arc<AtomicUsize>,
46}
47
48impl StreamService {
49    /// Create a new stream service
50    #[must_use]
51    pub fn new(name: String, backends: Vec<SocketAddr>) -> Self {
52        let health: HashMap<SocketAddr, BackendHealth> = backends
53            .iter()
54            .map(|addr| (*addr, BackendHealth::Unknown))
55            .collect();
56        Self {
57            name,
58            backends,
59            health: Arc::new(RwLock::new(health)),
60            rr_index: Arc::new(AtomicUsize::new(0)),
61        }
62    }
63
64    /// Select next backend using round-robin, skipping unhealthy backends.
65    ///
66    /// Tries up to `backends.len()` candidates. If all backends are unhealthy,
67    /// falls back to returning *any* backend (better than nothing).
68    #[must_use]
69    pub fn select_backend(&self) -> Option<SocketAddr> {
70        if self.backends.is_empty() {
71            return None;
72        }
73
74        let len = self.backends.len();
75        let start = self.rr_index.fetch_add(1, Ordering::Relaxed);
76
77        // Try to read health state without blocking; if the lock is held,
78        // just fall through to simple round-robin.
79        let health_guard = self.health.try_read();
80
81        if let Ok(health) = health_guard {
82            // First pass: find a healthy backend
83            for i in 0..len {
84                let idx = (start + i) % len;
85                let addr = self.backends[idx];
86                let status = health.get(&addr).copied().unwrap_or(BackendHealth::Unknown);
87                if status.is_usable() {
88                    return Some(addr);
89                }
90            }
91        }
92
93        // Fallback: all unhealthy or lock contention — use simple round-robin
94        Some(self.backends[start % len])
95    }
96
97    /// Update backend addresses (for scaling events).
98    ///
99    /// New backends start with `Unknown` health; removed backends are pruned
100    /// from the health map.
101    pub fn update_backends(&mut self, backends: Vec<SocketAddr>) {
102        // We need to block here since this is called from a &mut self context
103        // (inside DashMap::get_mut), so we can use blocking write.
104        let mut health = self
105            .health
106            .try_write()
107            .unwrap_or_else(|_| {
108                // In the extremely unlikely case of write contention, just proceed
109                // with a fresh health map.
110                tracing::warn!(service = %self.name, "Health map write contention during backend update");
111                // This should never actually happen since update_backends holds &mut self
112                unreachable!("update_backends requires exclusive access")
113            });
114
115        // Add new backends with Unknown health
116        for addr in &backends {
117            health.entry(*addr).or_insert(BackendHealth::Unknown);
118        }
119
120        // Remove backends that are no longer present
121        let backend_set: std::collections::HashSet<SocketAddr> = backends.iter().copied().collect();
122        health.retain(|addr, _| backend_set.contains(addr));
123
124        self.backends = backends;
125    }
126
127    /// Set the health status of a specific backend
128    pub async fn set_backend_health(&self, addr: SocketAddr, status: BackendHealth) {
129        let mut health = self.health.write().await;
130        if let Some(h) = health.get_mut(&addr) {
131            *h = status;
132        }
133    }
134
135    /// Get the health status of a specific backend
136    pub async fn get_backend_health(&self, addr: SocketAddr) -> BackendHealth {
137        let health = self.health.read().await;
138        health.get(&addr).copied().unwrap_or(BackendHealth::Unknown)
139    }
140
141    /// Get current backend count
142    #[must_use]
143    pub fn backend_count(&self) -> usize {
144        self.backends.len()
145    }
146
147    /// Get count of healthy (usable) backends
148    pub async fn healthy_count(&self) -> usize {
149        let health = self.health.read().await;
150        self.backends
151            .iter()
152            .filter(|addr| {
153                health
154                    .get(addr)
155                    .copied()
156                    .unwrap_or(BackendHealth::Unknown)
157                    .is_usable()
158            })
159            .count()
160    }
161}
162
163/// Registry for L4 stream services
164///
165/// Maps listen ports to services for both TCP and UDP protocols.
166#[derive(Default)]
167pub struct StreamRegistry {
168    /// TCP services by listen port
169    tcp_services: DashMap<u16, StreamService>,
170    /// UDP services by listen port
171    udp_services: DashMap<u16, StreamService>,
172}
173
174impl StreamRegistry {
175    /// Create a new empty registry
176    #[must_use]
177    pub fn new() -> Self {
178        Self::default()
179    }
180
181    /// Register a TCP service for a port
182    pub fn register_tcp(&self, port: u16, service: StreamService) {
183        tracing::debug!(
184            port = port,
185            service = %service.name,
186            backends = service.backend_count(),
187            "Registered TCP stream service"
188        );
189        self.tcp_services.insert(port, service);
190    }
191
192    /// Register a UDP service for a port
193    pub fn register_udp(&self, port: u16, service: StreamService) {
194        tracing::debug!(
195            port = port,
196            service = %service.name,
197            backends = service.backend_count(),
198            "Registered UDP stream service"
199        );
200        self.udp_services.insert(port, service);
201    }
202
203    /// Resolve TCP service for a port
204    #[must_use]
205    pub fn resolve_tcp(&self, port: u16) -> Option<StreamService> {
206        self.tcp_services.get(&port).map(|s| s.clone())
207    }
208
209    /// Resolve UDP service for a port
210    #[must_use]
211    pub fn resolve_udp(&self, port: u16) -> Option<StreamService> {
212        self.udp_services.get(&port).map(|s| s.clone())
213    }
214
215    /// Update backends for a TCP service
216    pub fn update_tcp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
217        if let Some(mut service) = self.tcp_services.get_mut(&port) {
218            tracing::debug!(
219                port = port,
220                service = %service.name,
221                old_count = service.backend_count(),
222                new_count = backends.len(),
223                "Updating TCP backends"
224            );
225            service.update_backends(backends);
226        }
227    }
228
229    /// Update backends for a UDP service
230    pub fn update_udp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
231        if let Some(mut service) = self.udp_services.get_mut(&port) {
232            tracing::debug!(
233                port = port,
234                service = %service.name,
235                old_count = service.backend_count(),
236                new_count = backends.len(),
237                "Updating UDP backends"
238            );
239            service.update_backends(backends);
240        }
241    }
242
243    /// Remove a TCP service
244    #[must_use]
245    pub fn unregister_tcp(&self, port: u16) -> Option<StreamService> {
246        self.tcp_services.remove(&port).map(|(_, s)| s)
247    }
248
249    /// Remove a UDP service
250    #[must_use]
251    pub fn unregister_udp(&self, port: u16) -> Option<StreamService> {
252        self.udp_services.remove(&port).map(|(_, s)| s)
253    }
254
255    /// Get count of registered TCP services
256    #[must_use]
257    pub fn tcp_count(&self) -> usize {
258        self.tcp_services.len()
259    }
260
261    /// Get count of registered UDP services
262    #[must_use]
263    pub fn udp_count(&self) -> usize {
264        self.udp_services.len()
265    }
266
267    /// List all registered TCP ports
268    #[must_use]
269    pub fn tcp_ports(&self) -> Vec<u16> {
270        self.tcp_services.iter().map(|e| *e.key()).collect()
271    }
272
273    /// List all registered UDP ports
274    #[must_use]
275    pub fn udp_ports(&self) -> Vec<u16> {
276        self.udp_services.iter().map(|e| *e.key()).collect()
277    }
278
279    /// List all registered TCP services with their listen ports.
280    #[must_use]
281    pub fn list_tcp_services(&self) -> Vec<(u16, StreamService)> {
282        self.tcp_services
283            .iter()
284            .map(|e| (*e.key(), e.value().clone()))
285            .collect()
286    }
287
288    /// List all registered UDP services with their listen ports.
289    #[must_use]
290    pub fn list_udp_services(&self) -> Vec<(u16, StreamService)> {
291        self.udp_services
292            .iter()
293            .map(|e| (*e.key(), e.value().clone()))
294            .collect()
295    }
296
297    /// Spawn a background health checker that periodically probes all
298    /// registered TCP backends with a connect-only health check.
299    ///
300    /// UDP backends are not probed (there is no reliable connectionless
301    /// health check). They remain `Unknown` and are always considered usable.
302    ///
303    /// The task runs every `interval` and uses `timeout` for each probe.
304    /// Returns a `JoinHandle` that can be used to cancel the checker.
305    #[must_use]
306    pub fn spawn_health_checker(
307        self: &Arc<Self>,
308        interval: Duration,
309        timeout: Duration,
310    ) -> tokio::task::JoinHandle<()> {
311        let registry = Arc::clone(self);
312
313        tokio::spawn(async move {
314            let mut ticker = tokio::time::interval(interval);
315            // Skip the first immediate tick
316            ticker.tick().await;
317
318            loop {
319                ticker.tick().await;
320
321                // Iterate all TCP services and probe each backend
322                for entry in &registry.tcp_services {
323                    let service = entry.value().clone();
324                    let backends = service.backends.clone();
325
326                    for addr in backends {
327                        let svc = service.clone();
328                        let probe_timeout = timeout;
329
330                        // Probe each backend concurrently
331                        tokio::spawn(async move {
332                            let result = tokio::time::timeout(
333                                probe_timeout,
334                                tokio::net::TcpStream::connect(addr),
335                            )
336                            .await;
337
338                            let health = match result {
339                                Ok(Ok(_stream)) => BackendHealth::Healthy,
340                                Ok(Err(e)) => {
341                                    tracing::debug!(
342                                        service = %svc.name,
343                                        backend = %addr,
344                                        error = %e,
345                                        "TCP health check failed (connect error)"
346                                    );
347                                    BackendHealth::Unhealthy
348                                }
349                                Err(_) => {
350                                    tracing::debug!(
351                                        service = %svc.name,
352                                        backend = %addr,
353                                        "TCP health check failed (timeout)"
354                                    );
355                                    BackendHealth::Unhealthy
356                                }
357                            };
358
359                            svc.set_backend_health(addr, health).await;
360                        });
361                    }
362                }
363            }
364        })
365    }
366}