// zlayer_agent/proxy_manager.rs
1//! Proxy management for agent-controlled services
2//!
3//! This module provides the `ProxyManager` struct that integrates the proxy crate
4//! with the agent's service management. It handles:
5//! - Managing proxy routes based on `ServiceSpec` endpoints (HTTP/HTTPS/WebSocket)
6//! - Managing L4 stream proxy listeners (TCP/UDP)
7//! - Tracking and updating backend servers for load balancing
8//! - Coordinating proxy server lifecycle
9
10use crate::error::Result;
11use std::collections::{HashMap, HashSet};
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::RwLock;
17use tracing::{debug, info, warn};
18use zlayer_proxy::{
19    load_existing_certs_into_resolver, CertManager, LbStrategy, LoadBalancer, NetworkPolicyChecker,
20    ProxyConfig, ProxyServer, RouteEntry, ServiceRegistry, SniCertResolver, StreamRegistry,
21    TcpStreamService, UdpStreamService,
22};
23use zlayer_spec::{ExposeType, Protocol, ServiceSpec};
24
/// Configuration for the `ProxyManager`
///
/// Controls where the L7 listeners bind and whether HTTP/2 is enabled.
/// Construct via [`Default`] or [`ProxyManagerConfig::new`] and refine
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ProxyManagerConfig {
    /// HTTP bind address (defaults to `0.0.0.0:80`)
    pub http_addr: SocketAddr,
    /// HTTPS bind address (optional; no HTTPS listener when `None`)
    pub https_addr: Option<SocketAddr>,
    /// Whether to enable HTTP/2
    pub http2_enabled: bool,
}
35
36impl Default for ProxyManagerConfig {
37    fn default() -> Self {
38        Self {
39            http_addr: "0.0.0.0:80".parse().unwrap(),
40            https_addr: None,
41            http2_enabled: true,
42        }
43    }
44}
45
46impl ProxyManagerConfig {
47    /// Create a new configuration with the specified HTTP address
48    #[must_use]
49    pub fn new(http_addr: SocketAddr) -> Self {
50        Self {
51            http_addr,
52            https_addr: None,
53            http2_enabled: true,
54        }
55    }
56
57    /// Set the HTTPS address
58    #[must_use]
59    pub fn with_https(mut self, addr: SocketAddr) -> Self {
60        self.https_addr = Some(addr);
61        self
62    }
63
64    /// Set HTTP/2 support
65    #[must_use]
66    pub fn with_http2(mut self, enabled: bool) -> Self {
67        self.http2_enabled = enabled;
68        self
69    }
70}
71
/// Per-service tracking information for cleanup purposes.
///
/// One entry is stored per registered service; `remove_service` consumes it
/// to tear down exactly the ports and routes that the service owned.
#[derive(Debug, Clone)]
struct ServiceTracking {
    /// Endpoint names (retained for Debug output and future introspection)
    #[allow(dead_code)]
    endpoint_names: Vec<String>,
    /// TCP ports owned by this service
    tcp_ports: Vec<u16>,
    /// UDP ports owned by this service
    udp_ports: Vec<u16>,
    /// HTTP/HTTPS/WebSocket ports owned by this service
    http_ports: Vec<u16>,
}
85
/// Manages proxy routing for agent-controlled services
///
/// The `ProxyManager` coordinates between the agent's service lifecycle and
/// the proxy crate's routing/load balancing infrastructure. It supports:
///
/// - **HTTP/HTTPS/WebSocket (L7)**: Multiple port listeners sharing the same
///   `ServiceRegistry` for request matching and load balancing.
/// - **TCP/UDP (L4)**: Standalone stream proxy listeners that forward raw
///   connections/datagrams to backends via the `StreamRegistry`.
pub struct ProxyManager {
    /// Configuration
    config: ProxyManagerConfig,
    /// Shared service registry for HTTP request matching and backend management
    registry: Arc<ServiceRegistry>,
    /// Load balancer for health-aware backend selection
    load_balancer: Arc<LoadBalancer>,
    /// Per-port HTTP proxy server handles
    servers: RwLock<HashMap<u16, Arc<ProxyServer>>>,
    /// Tracked services and their endpoints (includes port ownership for cleanup)
    services: RwLock<HashMap<String, ServiceTracking>>,
    /// Stream registry for L4 TCP/UDP proxy routing
    stream_registry: Option<Arc<StreamRegistry>>,
    /// Certificate manager for TLS
    cert_manager: Option<Arc<CertManager>>,
    /// Ports with active TCP stream listeners (to avoid double-binding)
    tcp_listeners: RwLock<HashSet<u16>>,
    /// Ports with active UDP stream listeners (to avoid double-binding)
    udp_listeners: RwLock<HashSet<u16>>,
    /// Number of active proxy connections (for graceful drain on shutdown).
    /// NOTE(review): nothing in this module increments this counter —
    /// presumably the proxy servers do; confirm before relying on the drain.
    active_connections: Arc<AtomicU64>,
    /// Optional network policy checker for access control enforcement
    network_policy_checker: Option<NetworkPolicyChecker>,
}
119
120impl ProxyManager {
121    /// Create a new `ProxyManager` with the given configuration, service registry,
122    /// and optional certificate manager.
123    pub fn new(
124        config: ProxyManagerConfig,
125        registry: Arc<ServiceRegistry>,
126        cert_manager: Option<Arc<CertManager>>,
127    ) -> Self {
128        let load_balancer = Arc::new(LoadBalancer::new());
129
130        Self {
131            config,
132            registry,
133            load_balancer,
134            servers: RwLock::new(HashMap::new()),
135            services: RwLock::new(HashMap::new()),
136            stream_registry: None,
137            cert_manager,
138            tcp_listeners: RwLock::new(HashSet::new()),
139            udp_listeners: RwLock::new(HashSet::new()),
140            active_connections: Arc::new(AtomicU64::new(0)),
141            network_policy_checker: None,
142        }
143    }
144
145    /// Get a reference to the service registry
146    pub fn registry(&self) -> Arc<ServiceRegistry> {
147        self.registry.clone()
148    }
149
150    /// Get a reference to the load balancer
151    pub fn load_balancer(&self) -> Arc<LoadBalancer> {
152        self.load_balancer.clone()
153    }
154
155    /// Get the number of currently active proxy connections.
156    pub fn active_connections(&self) -> u64 {
157        self.active_connections.load(Ordering::Relaxed)
158    }
159
160    /// Get a reference to the certificate manager (if configured)
161    pub fn cert_manager(&self) -> Option<&Arc<CertManager>> {
162        self.cert_manager.as_ref()
163    }
164
165    /// Set the stream registry for L4 proxy integration (TCP/UDP)
166    pub fn set_stream_registry(&mut self, registry: Arc<StreamRegistry>) {
167        self.stream_registry = Some(registry);
168    }
169
170    /// Builder pattern: add stream registry for L4 proxy integration
171    #[must_use]
172    pub fn with_stream_registry(mut self, registry: Arc<StreamRegistry>) -> Self {
173        self.stream_registry = Some(registry);
174        self
175    }
176
177    /// Get the stream registry (if configured)
178    pub fn stream_registry(&self) -> Option<&Arc<StreamRegistry>> {
179        self.stream_registry.as_ref()
180    }
181
182    /// Set the network policy checker for access control enforcement
183    pub fn set_network_policy_checker(&mut self, checker: NetworkPolicyChecker) {
184        self.network_policy_checker = Some(checker);
185    }
186
187    /// Builder pattern: add network policy checker for access control enforcement
188    #[must_use]
189    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
190        self.network_policy_checker = Some(checker);
191        self
192    }
193
194    /// Start listening on a specific port bound to the given address.
195    ///
196    /// If already listening on this port, skip.
197    /// All port listeners share the same `ServiceRegistry` for request matching.
198    ///
199    /// # Errors
200    /// Returns an error if the proxy server cannot be started.
201    pub async fn listen_on(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
202        let mut servers = self.servers.write().await;
203
204        if servers.contains_key(&port) {
205            debug!(port = port, "Already listening on port");
206            return Ok(());
207        }
208
209        let addr = SocketAddr::new(bind_ip, port);
210        let mut proxy_config = ProxyConfig::default();
211        proxy_config.server.http_addr = addr;
212        proxy_config.server.http2_enabled = self.config.http2_enabled;
213
214        let mut server = ProxyServer::with_registry(
215            proxy_config,
216            self.registry.clone(),
217            self.load_balancer.clone(),
218        );
219        if let Some(ref checker) = self.network_policy_checker {
220            server = server.with_network_policy_checker(checker.clone());
221        }
222        let server = Arc::new(server);
223
224        info!(port = port, bind = %addr, "Proxy listening on port");
225
226        let server_clone = server.clone();
227        tokio::spawn(async move {
228            if let Err(e) = server_clone.run().await {
229                tracing::error!(port = port, error = %e, "Proxy server error on port");
230            }
231        });
232
233        servers.insert(port, server);
234        Ok(())
235    }
236
237    /// Start an HTTPS listener on the given port using `SniCertResolver` for dynamic cert selection.
238    ///
239    /// If already listening on this port, skip.
240    /// Requires a `CertManager` to be configured; logs a warning and returns `Ok(())` if not.
241    ///
242    /// # Errors
243    /// Returns an error if the HTTPS proxy server cannot be started.
244    pub async fn listen_on_tls(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
245        let mut servers = self.servers.write().await;
246
247        if servers.contains_key(&port) {
248            debug!(port = port, "Already listening on port (TLS)");
249            return Ok(());
250        }
251
252        let Some(cert_manager) = &self.cert_manager else {
253            warn!(
254                port = port,
255                "Cannot start TLS listener: no CertManager configured"
256            );
257            return Ok(());
258        };
259
260        // Create SniCertResolver and load existing certs
261        let sni_resolver = Arc::new(SniCertResolver::new());
262
263        // Load existing certificates (best-effort; log warnings on failure)
264        let _ = load_existing_certs_into_resolver(cert_manager, &sni_resolver).await;
265
266        let addr = SocketAddr::new(bind_ip, port);
267        let mut proxy_config = ProxyConfig::default();
268        proxy_config.server.https_addr = addr;
269
270        let mut server = ProxyServer::with_tls_resolver(
271            proxy_config,
272            self.registry.clone(),
273            self.load_balancer.clone(),
274            sni_resolver,
275        )
276        .with_cert_manager(Arc::clone(cert_manager));
277        if let Some(ref checker) = self.network_policy_checker {
278            server = server.with_network_policy_checker(checker.clone());
279        }
280        let server = Arc::new(server);
281
282        info!(port = port, bind = %addr, "HTTPS proxy listening on port");
283
284        let server_clone = server.clone();
285        tokio::spawn(async move {
286            if let Err(e) = server_clone.run_https().await {
287                tracing::error!(port = port, error = %e, "HTTPS proxy server error");
288            }
289        });
290
291        servers.insert(port, server);
292        Ok(())
293    }
294
295    /// Stop all proxy servers on all ports.
296    ///
297    /// After signalling each server to shut down, waits up to 30 seconds for
298    /// active connections to drain before returning.
299    pub async fn stop(&self) {
300        let mut servers = self.servers.write().await;
301        for (port, server) in servers.drain() {
302            info!(port = port, "Stopping proxy on port");
303            server.shutdown();
304        }
305
306        // Wait up to 30s for active connections to drain
307        let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
308        while self.active_connections.load(Ordering::Relaxed) > 0 {
309            if tokio::time::Instant::now() >= deadline {
310                let remaining = self.active_connections.load(Ordering::Relaxed);
311                warn!(
312                    remaining = remaining,
313                    "Drain timeout reached, forcing shutdown"
314                );
315                break;
316            }
317            tokio::time::sleep(Duration::from_millis(100)).await;
318        }
319
320        info!("All proxy servers stopped");
321    }
322
323    /// Remove and shut down the listener on a specific port.
324    pub async fn unbind(&self, port: u16) {
325        let mut servers = self.servers.write().await;
326        if let Some(server) = servers.remove(&port) {
327            info!(port = port, "Unbinding proxy from port");
328            server.shutdown();
329        }
330    }
331
332    /// Scan a service's endpoints and ensure the proxy is listening on all
333    /// required ports.
334    ///
335    /// - **HTTP/HTTPS/WebSocket** endpoints start an HTTP proxy listener.
336    /// - **TCP** endpoints bind a `TcpListener` and spawn a `TcpStreamService`.
337    /// - **UDP** endpoints bind a `UdpSocket` and spawn a `UdpStreamService`.
338    ///
339    /// Bind address is determined by the `expose` type:
340    /// - **Public** endpoints bind to `0.0.0.0` (all interfaces).
341    /// - **Internal** endpoints bind to the overlay IP so they are only
342    ///   reachable from within the overlay network.  If no overlay is
343    ///   available, internal endpoints bind to `127.0.0.1` (localhost only).
344    ///
345    /// # Errors
346    /// Returns an error if an HTTP/HTTPS listener cannot be started.
347    pub async fn ensure_ports_for_service(
348        &self,
349        spec: &ServiceSpec,
350        overlay_ip: Option<IpAddr>,
351    ) -> Result<()> {
352        for endpoint in &spec.endpoints {
353            let bind_ip = match endpoint.expose {
354                ExposeType::Public => IpAddr::V4(Ipv4Addr::UNSPECIFIED), // 0.0.0.0
355                ExposeType::Internal => {
356                    // Prefer overlay IP; fall back to loopback if overlay is unavailable.
357                    let ip = overlay_ip.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
358                    if overlay_ip.is_none() {
359                        warn!(
360                            endpoint = %endpoint.name,
361                            port = endpoint.port,
362                            "No overlay IP available for internal endpoint; binding to 127.0.0.1"
363                        );
364                    }
365                    ip
366                }
367            };
368
369            match endpoint.protocol {
370                Protocol::Https => {
371                    // L7 TLS: start HTTPS proxy listener with SNI cert resolution
372                    self.listen_on_tls(endpoint.port, bind_ip).await?;
373                }
374                Protocol::Http | Protocol::Websocket => {
375                    // L7: start HTTP proxy listener
376                    self.listen_on(endpoint.port, bind_ip).await?;
377                }
378                Protocol::Tcp => {
379                    // L4 TCP: bind listener and spawn TcpStreamService
380                    self.ensure_tcp_listener(endpoint.port, bind_ip).await;
381                }
382                Protocol::Udp => {
383                    // L4 UDP: bind socket and spawn UdpStreamService
384                    self.ensure_udp_listener(endpoint.port, bind_ip).await;
385                }
386            }
387        }
388        Ok(())
389    }
390
391    /// Ensure a TCP stream listener is running on the given port.
392    ///
393    /// If a listener is already active on this port, this is a no-op.
394    /// Requires `stream_registry` to be configured; logs a warning if not.
395    async fn ensure_tcp_listener(&self, port: u16, bind_ip: IpAddr) {
396        // Check if already listening
397        {
398            let listeners = self.tcp_listeners.read().await;
399            if listeners.contains(&port) {
400                debug!(port = port, "TCP stream listener already active");
401                return;
402            }
403        }
404
405        let registry = if let Some(r) = &self.stream_registry {
406            Arc::clone(r)
407        } else {
408            warn!(
409                port = port,
410                "Cannot start TCP listener: StreamRegistry not configured"
411            );
412            return;
413        };
414
415        let addr = SocketAddr::new(bind_ip, port);
416        let listener = match tokio::net::TcpListener::bind(addr).await {
417            Ok(l) => l,
418            Err(e) => {
419                warn!(
420                    port = port,
421                    bind = %addr,
422                    error = %e,
423                    "Failed to bind TCP stream listener, continuing"
424                );
425                return;
426            }
427        };
428
429        // Mark as active before spawning
430        {
431            let mut listeners = self.tcp_listeners.write().await;
432            listeners.insert(port);
433        }
434
435        let tcp_service = Arc::new(TcpStreamService::new(registry, port));
436        tokio::spawn(async move {
437            tcp_service.serve(listener).await;
438        });
439
440        info!(port = port, bind = %addr, "TCP stream proxy listening");
441    }
442
443    /// Ensure a UDP stream listener is running on the given port.
444    ///
445    /// If a listener is already active on this port, this is a no-op.
446    /// Requires `stream_registry` to be configured; logs a warning if not.
447    async fn ensure_udp_listener(&self, port: u16, bind_ip: IpAddr) {
448        // Check if already listening
449        {
450            let listeners = self.udp_listeners.read().await;
451            if listeners.contains(&port) {
452                debug!(port = port, "UDP stream listener already active");
453                return;
454            }
455        }
456
457        let registry = if let Some(r) = &self.stream_registry {
458            Arc::clone(r)
459        } else {
460            warn!(
461                port = port,
462                "Cannot start UDP listener: StreamRegistry not configured"
463            );
464            return;
465        };
466
467        let addr = SocketAddr::new(bind_ip, port);
468        let socket = match tokio::net::UdpSocket::bind(addr).await {
469            Ok(s) => s,
470            Err(e) => {
471                warn!(
472                    port = port,
473                    bind = %addr,
474                    error = %e,
475                    "Failed to bind UDP stream listener, continuing"
476                );
477                return;
478            }
479        };
480
481        // Mark as active before spawning
482        {
483            let mut listeners = self.udp_listeners.write().await;
484            listeners.insert(port);
485        }
486
487        let udp_service = Arc::new(UdpStreamService::new(registry, port, None));
488        tokio::spawn(async move {
489            if let Err(e) = udp_service.serve(socket).await {
490                tracing::error!(
491                    port = port,
492                    error = %e,
493                    "UDP stream proxy service failed"
494                );
495            }
496        });
497
498        info!(port = port, bind = %addr, "UDP stream proxy listening");
499    }
500
501    /// Add routes for a service based on its specification
502    ///
503    /// This creates proxy routes for each endpoint defined in the `ServiceSpec`.
504    /// HTTP/HTTPS/WebSocket endpoints get L7 routes via the `ServiceRegistry`.
505    /// TCP/UDP endpoints are tracked but their L4 registration is handled
506    /// by the `ServiceManager::register_service_routes()` method.
507    pub async fn add_service(&self, name: &str, spec: &ServiceSpec) {
508        let mut services = self.services.write().await;
509
510        // Track which endpoints and ports we're adding
511        let mut endpoint_names = Vec::new();
512        let mut tcp_ports = Vec::new();
513        let mut udp_ports = Vec::new();
514        let mut http_ports = Vec::new();
515
516        for endpoint in &spec.endpoints {
517            match endpoint.protocol {
518                Protocol::Http | Protocol::Https | Protocol::Websocket => {
519                    // L7: register route in the ServiceRegistry
520                    let entry = RouteEntry::from_endpoint(name, endpoint);
521                    self.registry.register(entry).await;
522                    http_ports.push(endpoint.port);
523
524                    info!(
525                        service = name,
526                        endpoint = %endpoint.name,
527                        protocol = ?endpoint.protocol,
528                        path = ?endpoint.path,
529                        expose = ?endpoint.expose,
530                        "Added HTTP proxy route for service"
531                    );
532                }
533                Protocol::Tcp => {
534                    tcp_ports.push(endpoint.port);
535                    info!(
536                        service = name,
537                        endpoint = %endpoint.name,
538                        protocol = ?endpoint.protocol,
539                        port = endpoint.port,
540                        expose = ?endpoint.expose,
541                        "Tracking TCP stream endpoint for service"
542                    );
543                }
544                Protocol::Udp => {
545                    udp_ports.push(endpoint.port);
546                    info!(
547                        service = name,
548                        endpoint = %endpoint.name,
549                        protocol = ?endpoint.protocol,
550                        port = endpoint.port,
551                        expose = ?endpoint.expose,
552                        "Tracking UDP stream endpoint for service"
553                    );
554                }
555            }
556
557            endpoint_names.push(endpoint.name.clone());
558        }
559
560        // Register the service in the load balancer (starts with no backends)
561        self.load_balancer
562            .register(name, vec![], LbStrategy::RoundRobin);
563
564        services.insert(
565            name.to_string(),
566            ServiceTracking {
567                endpoint_names,
568                tcp_ports,
569                udp_ports,
570                http_ports,
571            },
572        );
573    }
574
575    /// Remove all routes, L4 listeners, and HTTP server handles for a service.
576    ///
577    /// This performs a full cleanup of all proxy resources associated with the
578    /// service:
579    /// - Removes L7 (HTTP/HTTPS/WebSocket) routes from the `ServiceRegistry`
580    /// - Unregisters TCP/UDP stream services from the `StreamRegistry`
581    /// - Removes port tracking for TCP/UDP listeners
582    /// - Shuts down HTTP proxy server handles that were exclusively owned by
583    ///   this service (only if no other service uses the same port)
584    pub async fn remove_service(&self, name: &str) {
585        let mut services = self.services.write().await;
586
587        if let Some(tracking) = services.remove(name) {
588            // 1. Remove L7 routes from the ServiceRegistry
589            self.registry.unregister_service(name).await;
590
591            // 1b. Remove from the load balancer
592            self.load_balancer.unregister(name);
593
594            // 2. Unregister TCP stream services and clear port tracking
595            if !tracking.tcp_ports.is_empty() {
596                let mut tcp_set = self.tcp_listeners.write().await;
597                for port in &tracking.tcp_ports {
598                    if let Some(registry) = &self.stream_registry {
599                        let _ = registry.unregister_tcp(*port);
600                    }
601                    tcp_set.remove(port);
602                    debug!(service = name, port = port, "Removed TCP listener tracking");
603                }
604            }
605
606            // 3. Unregister UDP stream services and clear port tracking
607            if !tracking.udp_ports.is_empty() {
608                let mut udp_set = self.udp_listeners.write().await;
609                for port in &tracking.udp_ports {
610                    if let Some(registry) = &self.stream_registry {
611                        let _ = registry.unregister_udp(*port);
612                    }
613                    udp_set.remove(port);
614                    debug!(service = name, port = port, "Removed UDP listener tracking");
615                }
616            }
617
618            // 4. Shut down HTTP proxy servers on ports exclusively owned by
619            //    this service (skip ports still used by other services)
620            if !tracking.http_ports.is_empty() {
621                let ports_still_in_use: HashSet<u16> = services
622                    .values()
623                    .flat_map(|t| t.http_ports.iter().copied())
624                    .collect();
625
626                let mut servers = self.servers.write().await;
627                for port in &tracking.http_ports {
628                    if !ports_still_in_use.contains(port) {
629                        if let Some(server) = servers.remove(port) {
630                            server.shutdown();
631                            info!(
632                                service = name,
633                                port = port,
634                                "Shut down HTTP proxy server (no remaining services on port)"
635                            );
636                        }
637                    }
638                }
639            }
640
641            info!(service = name, "Removed all proxy resources for service");
642        }
643    }
644
645    /// Add a single backend to a service
646    pub async fn add_backend(&self, service: &str, addr: SocketAddr) {
647        self.registry.add_backend(service, addr).await;
648        self.load_balancer.add_backend(service, addr);
649        info!(service = service, backend = %addr, "Registered backend with proxy");
650    }
651
652    /// Remove a backend from a service
653    pub async fn remove_backend(&self, service: &str, addr: SocketAddr) {
654        self.registry.remove_backend(service, addr).await;
655        self.load_balancer.remove_backend(service, &addr);
656        debug!(service = service, backend = %addr, "Removed backend from service");
657    }
658
659    /// Update the health status of a backend in the load balancer.
660    ///
661    /// Delegates to [`LoadBalancer::mark_health`] so that unhealthy backends
662    /// are skipped during selection.
663    #[allow(clippy::unused_async)]
664    pub async fn update_backend_health(&self, service: &str, addr: SocketAddr, healthy: bool) {
665        self.load_balancer.mark_health(service, &addr, healthy);
666        debug!(
667            service = service,
668            backend = %addr,
669            healthy = healthy,
670            "Updated backend health in load balancer"
671        );
672    }
673
674    /// Update the backends for a service
675    ///
676    /// This replaces all backends for the given service with the provided list.
677    /// Each backend should be the address where the service replica is listening.
678    pub async fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
679        self.registry.update_backends(service, addrs.clone()).await;
680        self.load_balancer.update_backends(service, addrs);
681        debug!(service = service, "Updated backends for service");
682    }
683
684    /// Get the number of registered routes
685    pub async fn route_count(&self) -> usize {
686        self.registry.route_count().await
687    }
688
689    /// Get the list of registered service names
690    pub async fn list_services(&self) -> Vec<String> {
691        self.services.read().await.keys().cloned().collect()
692    }
693
694    /// Check if a service has any registered endpoints
695    pub async fn has_service(&self, name: &str) -> bool {
696        self.services.read().await.contains_key(name)
697    }
698}
699
700#[cfg(test)]
701mod tests {
702    use super::*;
703
    /// Build a `ServiceSpec` with two L7 endpoints by parsing a minimal
    /// deployment YAML: a public `http` endpoint on 8080 (path `/api`) and an
    /// internal `websocket` endpoint on 8081 (path `/ws`).
    fn mock_service_spec_with_endpoints() -> ServiceSpec {
        use zlayer_spec::*;
        serde_yaml::from_str::<DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: http
        protocol: http
        port: 8080
        path: /api
        expose: public
      - name: websocket
        protocol: websocket
        port: 8081
        path: /ws
        expose: internal
",
        )
        .unwrap()
        .services
        .remove("test")
        .unwrap()
    }
733
    /// TCP-only `ServiceSpec` on the fixed test port 9000.
    fn mock_service_spec_tcp_only() -> ServiceSpec {
        mock_service_spec_tcp_only_port(9000)
    }
737
    /// Build a `ServiceSpec` with a single TCP endpoint (`grpc`) on the given
    /// port by parsing a minimal deployment YAML.
    fn mock_service_spec_tcp_only_port(port: u16) -> ServiceSpec {
        use zlayer_spec::*;
        let yaml = format!(
            "
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: grpc
        protocol: tcp
        port: {port}
"
        );
        serde_yaml::from_str::<DeploymentSpec>(&yaml)
            .unwrap()
            .services
            .remove("test")
            .unwrap()
    }
761
762    /// Reserve an unused localhost TCP port by binding a listener on `:0`,
763    /// reading the assigned port, and dropping the listener.
764    ///
765    /// There is an inherent race between dropping the listener and the test
766    /// re-binding the port, but this is dramatically more reliable than
767    /// hard-coding a port (e.g., 9000) which is commonly in use on dev
768    /// machines (php-fpm, the running zlayer daemon, etc.).
769    fn reserve_free_tcp_port() -> u16 {
770        let listener =
771            std::net::TcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral test port");
772        listener.local_addr().unwrap().port()
773    }
774
775    #[tokio::test]
776    async fn test_proxy_manager_new() {
777        let config = ProxyManagerConfig::default();
778        let registry = Arc::new(ServiceRegistry::new());
779        let manager = ProxyManager::new(config, registry, None);
780
781        assert_eq!(manager.route_count().await, 0);
782        assert!(manager.list_services().await.is_empty());
783    }
784
785    #[tokio::test]
786    async fn test_add_service_with_http_endpoints() {
787        let config = ProxyManagerConfig::default();
788        let registry = Arc::new(ServiceRegistry::new());
789        let manager = ProxyManager::new(config, registry, None);
790
791        let spec = mock_service_spec_with_endpoints();
792        manager.add_service("api", &spec).await;
793
794        // Should have 2 routes (http and websocket)
795        assert_eq!(manager.route_count().await, 2);
796        assert!(manager.has_service("api").await);
797    }
798
799    #[tokio::test]
800    async fn test_tcp_endpoints_tracked_not_routed() {
801        let config = ProxyManagerConfig::default();
802        let registry = Arc::new(ServiceRegistry::new());
803        let manager = ProxyManager::new(config, registry, None);
804
805        let spec = mock_service_spec_tcp_only();
806        manager.add_service("grpc-service", &spec).await;
807
808        // TCP endpoints don't add HTTP routes
809        assert_eq!(manager.route_count().await, 0);
810        // But the service is still tracked with its endpoint name
811        assert!(manager.has_service("grpc-service").await);
812    }
813
814    #[tokio::test]
815    async fn test_remove_service() {
816        let config = ProxyManagerConfig::default();
817        let registry = Arc::new(ServiceRegistry::new());
818        let manager = ProxyManager::new(config, registry, None);
819
820        let spec = mock_service_spec_with_endpoints();
821        manager.add_service("api", &spec).await;
822        assert_eq!(manager.route_count().await, 2);
823
824        manager.remove_service("api").await;
825        assert_eq!(manager.route_count().await, 0);
826        assert!(!manager.has_service("api").await);
827    }
828
829    #[tokio::test]
830    async fn test_backend_management() {
831        let config = ProxyManagerConfig::default();
832        let registry = Arc::new(ServiceRegistry::new());
833        let manager = ProxyManager::new(config, registry.clone(), None);
834
835        let spec = mock_service_spec_with_endpoints();
836        manager.add_service("api", &spec).await;
837
838        // Add backends
839        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
840        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
841
842        manager.add_backend("api", addr1).await;
843        manager.add_backend("api", addr2).await;
844
845        // Verify backends via the registry's resolve
846        let resolved = registry.resolve(None, "/api").await.unwrap();
847        assert_eq!(resolved.backends.len(), 2);
848
849        // Remove a backend
850        manager.remove_backend("api", addr1).await;
851        let resolved = registry.resolve(None, "/api").await.unwrap();
852        assert_eq!(resolved.backends.len(), 1);
853    }
854
855    #[tokio::test]
856    async fn test_update_backends_replaces_all() {
857        let config = ProxyManagerConfig::default();
858        let registry = Arc::new(ServiceRegistry::new());
859        let manager = ProxyManager::new(config, registry.clone(), None);
860
861        let spec = mock_service_spec_with_endpoints();
862        manager.add_service("api", &spec).await;
863
864        // Add initial backend
865        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
866        manager.add_backend("api", addr1).await;
867
868        // Update with new backends (replaces)
869        let new_backends: Vec<SocketAddr> = vec![
870            "127.0.0.1:9000".parse().unwrap(),
871            "127.0.0.1:9001".parse().unwrap(),
872            "127.0.0.1:9002".parse().unwrap(),
873        ];
874        manager.update_backends("api", new_backends).await;
875
876        let resolved = registry.resolve(None, "/api").await.unwrap();
877        assert_eq!(resolved.backends.len(), 3);
878    }
879
880    #[tokio::test]
881    async fn test_config_builder() {
882        let config = ProxyManagerConfig::new("0.0.0.0:8080".parse().unwrap())
883            .with_https("0.0.0.0:8443".parse().unwrap())
884            .with_http2(false);
885
886        assert_eq!(
887            config.http_addr,
888            "0.0.0.0:8080".parse::<SocketAddr>().unwrap()
889        );
890        assert_eq!(
891            config.https_addr,
892            Some("0.0.0.0:8443".parse::<SocketAddr>().unwrap())
893        );
894        assert!(!config.http2_enabled);
895    }
896
897    /// Test that `ensure_ports_for_service` correctly differentiates
898    /// Public (0.0.0.0) vs Internal (overlay or 127.0.0.1) bind addresses.
899    /// We can't actually bind in unit tests, but we verify the function
900    /// processes both endpoint types without error.
901    #[tokio::test]
902    async fn test_ensure_ports_differentiates_public_and_internal() {
903        let config = ProxyManagerConfig::default();
904        let registry = Arc::new(ServiceRegistry::new());
905        let manager = ProxyManager::new(config, registry, None);
906
907        let spec = mock_service_spec_with_endpoints();
908        // Passing None for overlay_ip: internal endpoints should fall back to 127.0.0.1
909        let result = manager.ensure_ports_for_service(&spec, None).await;
910        // listen_on may fail because we can't actually bind in tests, but
911        // the function itself should run without panicking.
912        let _ = result;
913    }
914
915    #[tokio::test]
916    async fn test_ensure_ports_with_overlay_ip() {
917        let config = ProxyManagerConfig::default();
918        let registry = Arc::new(ServiceRegistry::new());
919        let manager = ProxyManager::new(config, registry, None);
920
921        let spec = mock_service_spec_with_endpoints();
922        // Pass an overlay IP -- internal endpoints should bind there
923        let overlay_ip: IpAddr = "10.200.0.5".parse().unwrap();
924        let result = manager
925            .ensure_ports_for_service(&spec, Some(overlay_ip))
926            .await;
927        let _ = result;
928    }
929
930    fn mock_mixed_service_spec() -> ServiceSpec {
931        use zlayer_spec::*;
932        serde_yaml::from_str::<DeploymentSpec>(
933            r"
934version: v1
935deployment: test
936services:
937  mixed:
938    rtype: service
939    image:
940      name: test:latest
941    endpoints:
942      - name: http
943        protocol: http
944        port: 8080
945        path: /api
946        expose: public
947      - name: grpc
948        protocol: tcp
949        port: 9000
950        expose: public
951      - name: game
952        protocol: udp
953        port: 27015
954        expose: public
955",
956        )
957        .unwrap()
958        .services
959        .remove("mixed")
960        .unwrap()
961    }
962
963    #[tokio::test]
964    async fn test_add_mixed_service_tracks_all_endpoints() {
965        let config = ProxyManagerConfig::default();
966        let registry = Arc::new(ServiceRegistry::new());
967        let manager = ProxyManager::new(config, registry, None);
968
969        let spec = mock_mixed_service_spec();
970        manager.add_service("mixed", &spec).await;
971
972        // Only 1 HTTP route (tcp and udp don't add HTTP routes)
973        assert_eq!(manager.route_count().await, 1);
974        // Service is tracked
975        assert!(manager.has_service("mixed").await);
976    }
977
978    #[tokio::test]
979    async fn test_ensure_ports_tcp_with_stream_registry() {
980        use zlayer_proxy::StreamService;
981
982        let stream_registry = Arc::new(StreamRegistry::new());
983        let config = ProxyManagerConfig::default();
984        let registry = Arc::new(ServiceRegistry::new());
985        let mut manager = ProxyManager::new(config, registry, None);
986        manager.set_stream_registry(stream_registry.clone());
987
988        // Use an OS-assigned free port to avoid collisions with anything
989        // listening on the dev/CI box (e.g. php-fpm or a running zlayer
990        // daemon both default to port 9000 on 127.0.0.1).
991        let port = reserve_free_tcp_port();
992        let spec = mock_service_spec_tcp_only_port(port);
993
994        // Register the TCP service in the stream registry first (as ServiceManager does)
995        stream_registry.register_tcp(port, StreamService::new("grpc-service".to_string(), vec![]));
996
997        // Ensure ports -- should bind TCP listener
998        let result = manager.ensure_ports_for_service(&spec, None).await;
999        assert!(result.is_ok());
1000
1001        // Verify the TCP listener port is tracked
1002        let tcp_ports = manager.tcp_listeners.read().await;
1003        assert!(tcp_ports.contains(&port));
1004    }
1005
1006    #[tokio::test]
1007    async fn test_ensure_ports_tcp_without_stream_registry() {
1008        let config = ProxyManagerConfig::default();
1009        let registry = Arc::new(ServiceRegistry::new());
1010        let manager = ProxyManager::new(config, registry, None);
1011
1012        let spec = mock_service_spec_tcp_only();
1013
1014        // Without stream registry, ensure_ports should not fail, just warn
1015        let result = manager.ensure_ports_for_service(&spec, None).await;
1016        assert!(result.is_ok());
1017
1018        // No TCP listeners should be tracked
1019        let tcp_ports = manager.tcp_listeners.read().await;
1020        assert!(tcp_ports.is_empty());
1021    }
1022
1023    #[tokio::test]
1024    async fn test_stream_registry_setter() {
1025        let stream_registry = Arc::new(StreamRegistry::new());
1026        let config = ProxyManagerConfig::default();
1027        let registry = Arc::new(ServiceRegistry::new());
1028        let mut manager = ProxyManager::new(config, registry, None);
1029
1030        assert!(manager.stream_registry().is_none());
1031        manager.set_stream_registry(stream_registry.clone());
1032        assert!(manager.stream_registry().is_some());
1033    }
1034
1035    #[tokio::test]
1036    async fn test_registry_accessor() {
1037        let config = ProxyManagerConfig::default();
1038        let registry = Arc::new(ServiceRegistry::new());
1039        let manager = ProxyManager::new(config, registry.clone(), None);
1040
1041        // registry() should return the same Arc
1042        assert_eq!(Arc::as_ptr(&manager.registry()), Arc::as_ptr(&registry));
1043    }
1044}