1use crate::error::Result;
11use std::collections::{HashMap, HashSet};
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::RwLock;
17use tracing::{debug, info, warn};
18use zlayer_proxy::{
19 endpoint_lb_key, load_existing_certs_into_resolver, CertManager, LbStrategy, LoadBalancer,
20 NetworkPolicyChecker, ProxyConfig, ProxyServer, RouteEntry, ServiceRegistry, SniCertResolver,
21 StreamRegistry, StreamService, TcpStreamService, UdpStreamService,
22};
23use zlayer_spec::{ExposeType, Protocol, ServiceSpec};
24
/// Listener configuration for a [`ProxyManager`].
///
/// Build one with [`ProxyManagerConfig::new`] (or `Default`) and refine it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ProxyManagerConfig {
    /// HTTP listen address (`0.0.0.0:80` when built via `Default`).
    ///
    /// NOTE(review): the `ProxyManager` methods in this file take their bind address from
    /// their callers rather than from this field — confirm where it is consumed.
    pub http_addr: SocketAddr,
    /// Optional HTTPS listen address; `None` unless set with [`Self::with_https`].
    ///
    /// NOTE(review): like `http_addr`, not read by the visible `ProxyManager` methods.
    pub https_addr: Option<SocketAddr>,
    /// Whether HTTP/2 is enabled; copied into the `ProxyConfig` of the listeners the manager
    /// creates. Defaults to `true`.
    pub http2_enabled: bool,
}

impl Default for ProxyManagerConfig {
    /// Listens on `0.0.0.0:80` with HTTPS unset and HTTP/2 enabled.
    fn default() -> Self {
        // Build the address directly rather than parsing a string literal: no runtime parse,
        // no `unwrap`, and the remaining defaults live only in `new`.
        Self::new(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 80))
    }
}

impl ProxyManagerConfig {
    /// Creates a config for `http_addr` with HTTPS unset and HTTP/2 enabled.
    #[must_use]
    pub fn new(http_addr: SocketAddr) -> Self {
        Self {
            http_addr,
            https_addr: None,
            http2_enabled: true,
        }
    }

    /// Sets the HTTPS listen address.
    #[must_use]
    pub fn with_https(mut self, addr: SocketAddr) -> Self {
        self.https_addr = Some(addr);
        self
    }

    /// Enables or disables HTTP/2.
    #[must_use]
    pub fn with_http2(mut self, enabled: bool) -> Self {
        self.http2_enabled = enabled;
        self
    }
}
71
/// Bookkeeping recorded by `ProxyManager::add_service` so that
/// `ProxyManager::remove_service` and the backend-update methods know which
/// resources belong to a service.
#[derive(Clone, Debug)]
struct ServiceTracking {
    /// Names of every endpoint in the service spec, used to derive per-endpoint LB keys.
    endpoint_names: Vec<String>,
    /// Ports of the service's TCP endpoints.
    tcp_ports: Vec<u16>,
    /// Ports of the service's UDP endpoints.
    udp_ports: Vec<u16>,
    /// Ports of the service's HTTP, HTTPS and WebSocket endpoints.
    http_ports: Vec<u16>,
}
85
86pub struct ProxyManager {
96 config: ProxyManagerConfig,
98 registry: Arc<ServiceRegistry>,
100 load_balancer: Arc<LoadBalancer>,
102 servers: RwLock<HashMap<u16, Arc<ProxyServer>>>,
104 services: RwLock<HashMap<String, ServiceTracking>>,
106 stream_registry: Option<Arc<StreamRegistry>>,
108 cert_manager: Option<Arc<CertManager>>,
110 tcp_listeners: RwLock<HashSet<u16>>,
112 udp_listeners: RwLock<HashSet<u16>>,
114 active_connections: Arc<AtomicU64>,
116 network_policy_checker: Option<NetworkPolicyChecker>,
118 loopback_registry: Arc<StreamRegistry>,
128 loopback_tcp: RwLock<HashMap<u16, tokio::task::JoinHandle<()>>>,
132 loopback_udp: RwLock<HashMap<u16, tokio::task::JoinHandle<()>>>,
135 lb_health_checker: tokio::task::JoinHandle<()>,
144}
145
146impl Drop for ProxyManager {
147 fn drop(&mut self) {
148 self.lb_health_checker.abort();
149 }
150}
151
152impl ProxyManager {
153 pub fn new(
156 config: ProxyManagerConfig,
157 registry: Arc<ServiceRegistry>,
158 cert_manager: Option<Arc<CertManager>>,
159 ) -> Self {
160 let load_balancer = Arc::new(LoadBalancer::new());
161
162 let lb_health_checker =
168 load_balancer.spawn_health_checker(Duration::from_secs(5), Duration::from_secs(2));
169
170 Self {
171 config,
172 registry,
173 load_balancer,
174 servers: RwLock::new(HashMap::new()),
175 services: RwLock::new(HashMap::new()),
176 stream_registry: None,
177 cert_manager,
178 tcp_listeners: RwLock::new(HashSet::new()),
179 udp_listeners: RwLock::new(HashSet::new()),
180 active_connections: Arc::new(AtomicU64::new(0)),
181 network_policy_checker: None,
182 loopback_registry: Arc::new(StreamRegistry::new()),
183 loopback_tcp: RwLock::new(HashMap::new()),
184 loopback_udp: RwLock::new(HashMap::new()),
185 lb_health_checker,
186 }
187 }
188
189 pub fn registry(&self) -> Arc<ServiceRegistry> {
191 self.registry.clone()
192 }
193
194 pub fn load_balancer(&self) -> Arc<LoadBalancer> {
196 self.load_balancer.clone()
197 }
198
199 pub fn active_connections(&self) -> u64 {
201 self.active_connections.load(Ordering::Relaxed)
202 }
203
204 pub fn cert_manager(&self) -> Option<&Arc<CertManager>> {
206 self.cert_manager.as_ref()
207 }
208
209 pub fn set_stream_registry(&mut self, registry: Arc<StreamRegistry>) {
211 self.stream_registry = Some(registry);
212 }
213
214 #[must_use]
216 pub fn with_stream_registry(mut self, registry: Arc<StreamRegistry>) -> Self {
217 self.stream_registry = Some(registry);
218 self
219 }
220
221 pub fn stream_registry(&self) -> Option<&Arc<StreamRegistry>> {
223 self.stream_registry.as_ref()
224 }
225
226 pub fn set_network_policy_checker(&mut self, checker: NetworkPolicyChecker) {
228 self.network_policy_checker = Some(checker);
229 }
230
231 #[must_use]
233 pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
234 self.network_policy_checker = Some(checker);
235 self
236 }
237
238 pub async fn listen_on(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
246 let mut servers = self.servers.write().await;
247
248 if servers.contains_key(&port) {
249 debug!(port = port, "Already listening on port");
250 return Ok(());
251 }
252
253 let addr = SocketAddr::new(bind_ip, port);
254 let mut proxy_config = ProxyConfig::default();
255 proxy_config.server.http_addr = addr;
256 proxy_config.server.http2_enabled = self.config.http2_enabled;
257
258 let mut server = ProxyServer::with_registry(
259 proxy_config,
260 self.registry.clone(),
261 self.load_balancer.clone(),
262 );
263 if let Some(ref checker) = self.network_policy_checker {
264 server = server.with_network_policy_checker(checker.clone());
265 }
266 let server = Arc::new(server);
267
268 info!(port = port, bind = %addr, "Proxy listening on port");
269
270 let server_clone = server.clone();
271 tokio::spawn(async move {
272 if let Err(e) = server_clone.run().await {
273 tracing::error!(port = port, error = %e, "Proxy server error on port");
274 }
275 });
276
277 servers.insert(port, server);
278 Ok(())
279 }
280
281 pub async fn listen_on_tls(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
289 let mut servers = self.servers.write().await;
290
291 if servers.contains_key(&port) {
292 debug!(port = port, "Already listening on port (TLS)");
293 return Ok(());
294 }
295
296 let Some(cert_manager) = &self.cert_manager else {
297 warn!(
298 port = port,
299 "Cannot start TLS listener: no CertManager configured"
300 );
301 return Ok(());
302 };
303
304 let sni_resolver = Arc::new(SniCertResolver::new());
306
307 let _ = load_existing_certs_into_resolver(cert_manager, &sni_resolver).await;
309
310 let addr = SocketAddr::new(bind_ip, port);
311 let mut proxy_config = ProxyConfig::default();
312 proxy_config.server.https_addr = addr;
313
314 let mut server = ProxyServer::with_tls_resolver(
315 proxy_config,
316 self.registry.clone(),
317 self.load_balancer.clone(),
318 sni_resolver,
319 )
320 .with_cert_manager(Arc::clone(cert_manager));
321 if let Some(ref checker) = self.network_policy_checker {
322 server = server.with_network_policy_checker(checker.clone());
323 }
324 let server = Arc::new(server);
325
326 info!(port = port, bind = %addr, "HTTPS proxy listening on port");
327
328 let server_clone = server.clone();
329 tokio::spawn(async move {
330 if let Err(e) = server_clone.run_https().await {
331 tracing::error!(port = port, error = %e, "HTTPS proxy server error");
332 }
333 });
334
335 servers.insert(port, server);
336 Ok(())
337 }
338
339 pub async fn stop(&self) {
344 let mut servers = self.servers.write().await;
345 for (port, server) in servers.drain() {
346 info!(port = port, "Stopping proxy on port");
347 server.shutdown();
348 }
349
350 let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
352 while self.active_connections.load(Ordering::Relaxed) > 0 {
353 if tokio::time::Instant::now() >= deadline {
354 let remaining = self.active_connections.load(Ordering::Relaxed);
355 warn!(
356 remaining = remaining,
357 "Drain timeout reached, forcing shutdown"
358 );
359 break;
360 }
361 tokio::time::sleep(Duration::from_millis(100)).await;
362 }
363
364 info!("All proxy servers stopped");
365 }
366
367 pub async fn unbind(&self, port: u16) {
369 let mut servers = self.servers.write().await;
370 if let Some(server) = servers.remove(&port) {
371 info!(port = port, "Unbinding proxy from port");
372 server.shutdown();
373 }
374 }
375
376 pub async fn ensure_ports_for_service(
392 &self,
393 spec: &ServiceSpec,
394 overlay_ip: Option<IpAddr>,
395 ) -> Result<()> {
396 for endpoint in &spec.endpoints {
397 let bind_ip = match endpoint.expose {
398 ExposeType::Public => IpAddr::V4(Ipv4Addr::UNSPECIFIED), ExposeType::Internal => {
400 let ip = overlay_ip.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
402 if overlay_ip.is_none() {
403 warn!(
404 endpoint = %endpoint.name,
405 port = endpoint.port,
406 "No overlay IP available for internal endpoint; binding to 127.0.0.1"
407 );
408 }
409 ip
410 }
411 };
412
413 match endpoint.protocol {
414 Protocol::Https => {
415 self.listen_on_tls(endpoint.port, bind_ip).await?;
417 }
418 Protocol::Http | Protocol::Websocket => {
419 self.listen_on(endpoint.port, bind_ip).await?;
421 }
422 Protocol::Tcp => {
423 self.ensure_tcp_listener(endpoint.port, bind_ip).await;
425 }
426 Protocol::Udp => {
427 self.ensure_udp_listener(endpoint.port, bind_ip).await;
429 }
430 }
431 }
432 Ok(())
433 }
434
435 async fn ensure_tcp_listener(&self, port: u16, bind_ip: IpAddr) {
440 {
442 let listeners = self.tcp_listeners.read().await;
443 if listeners.contains(&port) {
444 debug!(port = port, "TCP stream listener already active");
445 return;
446 }
447 }
448
449 let registry = if let Some(r) = &self.stream_registry {
450 Arc::clone(r)
451 } else {
452 warn!(
453 port = port,
454 "Cannot start TCP listener: StreamRegistry not configured"
455 );
456 return;
457 };
458
459 let addr = SocketAddr::new(bind_ip, port);
460 let listener = match tokio::net::TcpListener::bind(addr).await {
461 Ok(l) => l,
462 Err(e) => {
463 warn!(
464 port = port,
465 bind = %addr,
466 error = %e,
467 "Failed to bind TCP stream listener, continuing"
468 );
469 return;
470 }
471 };
472
473 {
475 let mut listeners = self.tcp_listeners.write().await;
476 listeners.insert(port);
477 }
478
479 let tcp_service = Arc::new(TcpStreamService::new(registry, port));
480 tokio::spawn(async move {
481 tcp_service.serve(listener).await;
482 });
483
484 info!(port = port, bind = %addr, "TCP stream proxy listening");
485 }
486
487 async fn ensure_udp_listener(&self, port: u16, bind_ip: IpAddr) {
492 {
494 let listeners = self.udp_listeners.read().await;
495 if listeners.contains(&port) {
496 debug!(port = port, "UDP stream listener already active");
497 return;
498 }
499 }
500
501 let registry = if let Some(r) = &self.stream_registry {
502 Arc::clone(r)
503 } else {
504 warn!(
505 port = port,
506 "Cannot start UDP listener: StreamRegistry not configured"
507 );
508 return;
509 };
510
511 let addr = SocketAddr::new(bind_ip, port);
512 let socket = match tokio::net::UdpSocket::bind(addr).await {
513 Ok(s) => s,
514 Err(e) => {
515 warn!(
516 port = port,
517 bind = %addr,
518 error = %e,
519 "Failed to bind UDP stream listener, continuing"
520 );
521 return;
522 }
523 };
524
525 {
527 let mut listeners = self.udp_listeners.write().await;
528 listeners.insert(port);
529 }
530
531 let udp_service = Arc::new(UdpStreamService::new(registry, port, None));
532 tokio::spawn(async move {
533 if let Err(e) = udp_service.serve(socket).await {
534 tracing::error!(
535 port = port,
536 error = %e,
537 "UDP stream proxy service failed"
538 );
539 }
540 });
541
542 info!(port = port, bind = %addr, "UDP stream proxy listening");
543 }
544
545 pub async fn publish_loopback_for_container(
576 &self,
577 service_name: &str,
578 spec: &ServiceSpec,
579 container_ip: IpAddr,
580 port_override: Option<u16>,
581 ) {
582 for endpoint in &spec.endpoints {
583 if matches!(endpoint.expose, ExposeType::Public) {
585 continue;
586 }
587
588 let backend = SocketAddr::new(
589 container_ip,
590 port_override.unwrap_or_else(|| endpoint.target_port()),
591 );
592 let publish_port = endpoint.port;
593
594 match endpoint.protocol {
595 Protocol::Tcp | Protocol::Http | Protocol::Https | Protocol::Websocket => {
596 self.publish_loopback_tcp(service_name, publish_port, backend)
599 .await;
600 }
601 Protocol::Udp => {
602 self.publish_loopback_udp(service_name, publish_port, backend)
603 .await;
604 }
605 }
606 }
607 }
608
609 async fn publish_loopback_tcp(
612 &self,
613 service_name: &str,
614 publish_port: u16,
615 backend: SocketAddr,
616 ) {
617 if let Some(existing) = self.loopback_registry.resolve_tcp(publish_port) {
619 let mut backends = existing.backends;
620 if !backends.contains(&backend) {
621 backends.push(backend);
622 }
623 self.loopback_registry
624 .update_tcp_backends(publish_port, backends);
625 } else {
626 self.loopback_registry.register_tcp(
627 publish_port,
628 StreamService::new(service_name.to_string(), vec![backend]),
629 );
630 }
631
632 let mut listeners = self.loopback_tcp.write().await;
634 if listeners.contains_key(&publish_port) {
635 debug!(port = publish_port, "Loopback TCP listener already active");
636 return;
637 }
638
639 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), publish_port);
640 let listener = match tokio::net::TcpListener::bind(addr).await {
641 Ok(l) => l,
642 Err(e) => {
643 warn!(
644 port = publish_port,
645 bind = %addr,
646 error = %e,
647 "Failed to bind loopback TCP listener, continuing"
648 );
649 return;
650 }
651 };
652
653 let tcp_service = Arc::new(TcpStreamService::new(
654 Arc::clone(&self.loopback_registry),
655 publish_port,
656 ));
657 let handle = tokio::spawn(async move {
658 tcp_service.serve(listener).await;
659 });
660 listeners.insert(publish_port, handle);
661 drop(listeners);
662
663 info!(
664 service = service_name,
665 port = publish_port,
666 bind = %addr,
667 backend = %backend,
668 "Published service port on node loopback (TCP)"
669 );
670 }
671
672 async fn publish_loopback_udp(
675 &self,
676 service_name: &str,
677 publish_port: u16,
678 backend: SocketAddr,
679 ) {
680 if let Some(existing) = self.loopback_registry.resolve_udp(publish_port) {
681 let mut backends = existing.backends;
682 if !backends.contains(&backend) {
683 backends.push(backend);
684 }
685 self.loopback_registry
686 .update_udp_backends(publish_port, backends);
687 } else {
688 self.loopback_registry.register_udp(
689 publish_port,
690 StreamService::new(service_name.to_string(), vec![backend]),
691 );
692 }
693
694 let mut listeners = self.loopback_udp.write().await;
695 if listeners.contains_key(&publish_port) {
696 debug!(port = publish_port, "Loopback UDP listener already active");
697 return;
698 }
699
700 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), publish_port);
701 let socket = match tokio::net::UdpSocket::bind(addr).await {
702 Ok(s) => s,
703 Err(e) => {
704 warn!(
705 port = publish_port,
706 bind = %addr,
707 error = %e,
708 "Failed to bind loopback UDP listener, continuing"
709 );
710 return;
711 }
712 };
713
714 let udp_service = Arc::new(UdpStreamService::new(
715 Arc::clone(&self.loopback_registry),
716 publish_port,
717 None,
718 ));
719 let handle = tokio::spawn(async move {
720 if let Err(e) = udp_service.serve(socket).await {
721 tracing::error!(
722 port = publish_port,
723 error = %e,
724 "Loopback UDP stream proxy service failed"
725 );
726 }
727 });
728 listeners.insert(publish_port, handle);
729 drop(listeners);
730
731 info!(
732 service = service_name,
733 port = publish_port,
734 bind = %addr,
735 backend = %backend,
736 "Published service port on node loopback (UDP)"
737 );
738 }
739
740 pub async fn unpublish_loopback_for_container(
750 &self,
751 spec: &ServiceSpec,
752 container_ip: IpAddr,
753 port_override: Option<u16>,
754 ) {
755 for endpoint in &spec.endpoints {
756 if matches!(endpoint.expose, ExposeType::Public) {
757 continue;
758 }
759
760 let backend = SocketAddr::new(
761 container_ip,
762 port_override.unwrap_or_else(|| endpoint.target_port()),
763 );
764 let publish_port = endpoint.port;
765
766 match endpoint.protocol {
767 Protocol::Tcp | Protocol::Http | Protocol::Https | Protocol::Websocket => {
768 self.unpublish_loopback_tcp(publish_port, backend).await;
769 }
770 Protocol::Udp => {
771 self.unpublish_loopback_udp(publish_port, backend).await;
772 }
773 }
774 }
775 }
776
777 async fn unpublish_loopback_tcp(&self, publish_port: u16, backend: SocketAddr) {
780 let Some(existing) = self.loopback_registry.resolve_tcp(publish_port) else {
781 return;
782 };
783 let remaining: Vec<SocketAddr> = existing
784 .backends
785 .into_iter()
786 .filter(|b| *b != backend)
787 .collect();
788
789 if remaining.is_empty() {
790 let _ = self.loopback_registry.unregister_tcp(publish_port);
791 let mut listeners = self.loopback_tcp.write().await;
792 if let Some(handle) = listeners.remove(&publish_port) {
793 handle.abort();
794 }
795 debug!(
796 port = publish_port,
797 "Freed loopback TCP listener (no backends remain)"
798 );
799 } else {
800 self.loopback_registry
801 .update_tcp_backends(publish_port, remaining);
802 }
803 }
804
805 async fn unpublish_loopback_udp(&self, publish_port: u16, backend: SocketAddr) {
808 let Some(existing) = self.loopback_registry.resolve_udp(publish_port) else {
809 return;
810 };
811 let remaining: Vec<SocketAddr> = existing
812 .backends
813 .into_iter()
814 .filter(|b| *b != backend)
815 .collect();
816
817 if remaining.is_empty() {
818 let _ = self.loopback_registry.unregister_udp(publish_port);
819 let mut listeners = self.loopback_udp.write().await;
820 if let Some(handle) = listeners.remove(&publish_port) {
821 handle.abort();
822 }
823 debug!(
824 port = publish_port,
825 "Freed loopback UDP listener (no backends remain)"
826 );
827 } else {
828 self.loopback_registry
829 .update_udp_backends(publish_port, remaining);
830 }
831 }
832
833 pub async fn add_service(&self, name: &str, spec: &ServiceSpec) {
840 let mut services = self.services.write().await;
841
842 let mut endpoint_names = Vec::new();
844 let mut tcp_ports = Vec::new();
845 let mut udp_ports = Vec::new();
846 let mut http_ports = Vec::new();
847
848 for endpoint in &spec.endpoints {
849 match endpoint.protocol {
850 Protocol::Http | Protocol::Https | Protocol::Websocket => {
851 let entry = RouteEntry::from_endpoint(name, endpoint);
853 self.registry.register(entry).await;
854 http_ports.push(endpoint.port);
855
856 let lb_key = endpoint_lb_key(name, &endpoint.name);
863 self.load_balancer
864 .register(&lb_key, vec![], LbStrategy::RoundRobin);
865
866 info!(
867 service = name,
868 endpoint = %endpoint.name,
869 protocol = ?endpoint.protocol,
870 path = ?endpoint.path,
871 expose = ?endpoint.expose,
872 "Added HTTP proxy route for service"
873 );
874 }
875 Protocol::Tcp => {
876 tcp_ports.push(endpoint.port);
877 info!(
878 service = name,
879 endpoint = %endpoint.name,
880 protocol = ?endpoint.protocol,
881 port = endpoint.port,
882 expose = ?endpoint.expose,
883 "Tracking TCP stream endpoint for service"
884 );
885 }
886 Protocol::Udp => {
887 udp_ports.push(endpoint.port);
888 info!(
889 service = name,
890 endpoint = %endpoint.name,
891 protocol = ?endpoint.protocol,
892 port = endpoint.port,
893 expose = ?endpoint.expose,
894 "Tracking UDP stream endpoint for service"
895 );
896 }
897 }
898
899 endpoint_names.push(endpoint.name.clone());
900 }
901
902 self.load_balancer
909 .register(name, vec![], LbStrategy::RoundRobin);
910
911 services.insert(
912 name.to_string(),
913 ServiceTracking {
914 endpoint_names,
915 tcp_ports,
916 udp_ports,
917 http_ports,
918 },
919 );
920 }
921
922 pub async fn remove_service(&self, name: &str) {
932 let mut services = self.services.write().await;
933
934 if let Some(tracking) = services.remove(name) {
935 self.registry.unregister_service(name).await;
937
938 self.load_balancer.unregister(name);
941 for endpoint_name in &tracking.endpoint_names {
942 let lb_key = endpoint_lb_key(name, endpoint_name);
943 self.load_balancer.unregister(&lb_key);
944 }
945
946 if !tracking.tcp_ports.is_empty() {
948 let mut tcp_set = self.tcp_listeners.write().await;
949 for port in &tracking.tcp_ports {
950 if let Some(registry) = &self.stream_registry {
951 let _ = registry.unregister_tcp(*port);
952 }
953 tcp_set.remove(port);
954 debug!(service = name, port = port, "Removed TCP listener tracking");
955 }
956 }
957
958 if !tracking.udp_ports.is_empty() {
960 let mut udp_set = self.udp_listeners.write().await;
961 for port in &tracking.udp_ports {
962 if let Some(registry) = &self.stream_registry {
963 let _ = registry.unregister_udp(*port);
964 }
965 udp_set.remove(port);
966 debug!(service = name, port = port, "Removed UDP listener tracking");
967 }
968 }
969
970 if !tracking.http_ports.is_empty() {
973 let ports_still_in_use: HashSet<u16> = services
974 .values()
975 .flat_map(|t| t.http_ports.iter().copied())
976 .collect();
977
978 let mut servers = self.servers.write().await;
979 for port in &tracking.http_ports {
980 if !ports_still_in_use.contains(port) {
981 if let Some(server) = servers.remove(port) {
982 server.shutdown();
983 info!(
984 service = name,
985 port = port,
986 "Shut down HTTP proxy server (no remaining services on port)"
987 );
988 }
989 }
990 }
991 }
992
993 info!(service = name, "Removed all proxy resources for service");
994 }
995 }
996
997 pub async fn add_backend(&self, service: &str, addr: SocketAddr) {
1004 self.registry.add_backend(service, addr).await;
1005 self.load_balancer.add_backend(service, addr);
1006 let services = self.services.read().await;
1008 if let Some(tracking) = services.get(service) {
1009 for endpoint_name in &tracking.endpoint_names {
1010 let lb_key = endpoint_lb_key(service, endpoint_name);
1011 self.load_balancer.add_backend(&lb_key, addr);
1012 }
1013 }
1014 info!(service = service, backend = %addr, "Registered backend with proxy");
1015 }
1016
1017 pub async fn remove_backend(&self, service: &str, addr: SocketAddr) {
1022 self.registry.remove_backend(service, addr).await;
1023 self.load_balancer.remove_backend(service, &addr);
1024 let services = self.services.read().await;
1025 if let Some(tracking) = services.get(service) {
1026 for endpoint_name in &tracking.endpoint_names {
1027 let lb_key = endpoint_lb_key(service, endpoint_name);
1028 self.load_balancer.remove_backend(&lb_key, &addr);
1029 }
1030 }
1031 debug!(service = service, backend = %addr, "Removed backend from service");
1032 }
1033
1034 #[allow(clippy::unused_async)]
1041 pub async fn update_backend_health(&self, service: &str, addr: SocketAddr, healthy: bool) {
1042 self.load_balancer.mark_health(service, &addr, healthy);
1043 let services = self.services.read().await;
1044 if let Some(tracking) = services.get(service) {
1045 for endpoint_name in &tracking.endpoint_names {
1046 let lb_key = endpoint_lb_key(service, endpoint_name);
1047 self.load_balancer.mark_health(&lb_key, &addr, healthy);
1048 }
1049 }
1050 debug!(
1051 service = service,
1052 backend = %addr,
1053 healthy = healthy,
1054 "Updated backend health in load balancer"
1055 );
1056 }
1057
1058 pub async fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
1066 self.registry.update_backends(service, addrs.clone()).await;
1067 self.load_balancer.update_backends(service, addrs.clone());
1069 let services = self.services.read().await;
1070 if let Some(tracking) = services.get(service) {
1071 for endpoint_name in &tracking.endpoint_names {
1072 let lb_key = endpoint_lb_key(service, endpoint_name);
1073 self.load_balancer.update_backends(&lb_key, addrs.clone());
1074 }
1075 }
1076 debug!(service = service, "Updated backends for service");
1077 }
1078
1079 pub async fn update_endpoint_backends(
1086 &self,
1087 service: &str,
1088 endpoint_name: &str,
1089 addrs: Vec<SocketAddr>,
1090 ) {
1091 self.registry
1092 .update_backends_for_endpoint(service, endpoint_name, addrs.clone())
1093 .await;
1094 let lb_key = endpoint_lb_key(service, endpoint_name);
1095 self.load_balancer.update_backends(&lb_key, addrs);
1096 debug!(
1097 service = service,
1098 endpoint = endpoint_name,
1099 "Updated backends for service endpoint"
1100 );
1101 }
1102
1103 pub async fn route_count(&self) -> usize {
1105 self.registry.route_count().await
1106 }
1107
1108 pub async fn list_services(&self) -> Vec<String> {
1110 self.services.read().await.keys().cloned().collect()
1111 }
1112
1113 pub async fn has_service(&self, name: &str) -> bool {
1115 self.services.read().await.contains_key(name)
1116 }
1117}
1118
1119#[cfg(test)]
1120mod tests {
1121 use super::*;
1122
1123 fn mock_service_spec_with_endpoints() -> ServiceSpec {
1124 use zlayer_spec::*;
1125 serde_yaml::from_str::<DeploymentSpec>(
1126 r"
1127version: v1
1128deployment: test
1129services:
1130 test:
1131 rtype: service
1132 image:
1133 name: test:latest
1134 endpoints:
1135 - name: http
1136 protocol: http
1137 port: 8080
1138 path: /api
1139 expose: public
1140 - name: websocket
1141 protocol: websocket
1142 port: 8081
1143 path: /ws
1144 expose: internal
1145",
1146 )
1147 .unwrap()
1148 .services
1149 .remove("test")
1150 .unwrap()
1151 }
1152
1153 fn mock_service_spec_tcp_only() -> ServiceSpec {
1154 mock_service_spec_tcp_only_port(9000)
1155 }
1156
1157 fn mock_service_spec_tcp_only_port(port: u16) -> ServiceSpec {
1158 use zlayer_spec::*;
1159 let yaml = format!(
1160 "
1161version: v1
1162deployment: test
1163services:
1164 test:
1165 rtype: service
1166 image:
1167 name: test:latest
1168 endpoints:
1169 - name: grpc
1170 protocol: tcp
1171 port: {port}
1172"
1173 );
1174 serde_yaml::from_str::<DeploymentSpec>(&yaml)
1175 .unwrap()
1176 .services
1177 .remove("test")
1178 .unwrap()
1179 }
1180
1181 fn reserve_free_tcp_port() -> u16 {
1189 let listener =
1190 std::net::TcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral test port");
1191 listener.local_addr().unwrap().port()
1192 }
1193
1194 #[tokio::test]
1195 async fn test_proxy_manager_new() {
1196 let config = ProxyManagerConfig::default();
1197 let registry = Arc::new(ServiceRegistry::new());
1198 let manager = ProxyManager::new(config, registry, None);
1199
1200 assert_eq!(manager.route_count().await, 0);
1201 assert!(manager.list_services().await.is_empty());
1202 }
1203
1204 #[tokio::test]
1205 async fn test_add_service_with_http_endpoints() {
1206 let config = ProxyManagerConfig::default();
1207 let registry = Arc::new(ServiceRegistry::new());
1208 let manager = ProxyManager::new(config, registry, None);
1209
1210 let spec = mock_service_spec_with_endpoints();
1211 manager.add_service("api", &spec).await;
1212
1213 assert_eq!(manager.route_count().await, 2);
1215 assert!(manager.has_service("api").await);
1216 }
1217
1218 #[tokio::test]
1219 async fn test_tcp_endpoints_tracked_not_routed() {
1220 let config = ProxyManagerConfig::default();
1221 let registry = Arc::new(ServiceRegistry::new());
1222 let manager = ProxyManager::new(config, registry, None);
1223
1224 let spec = mock_service_spec_tcp_only();
1225 manager.add_service("grpc-service", &spec).await;
1226
1227 assert_eq!(manager.route_count().await, 0);
1229 assert!(manager.has_service("grpc-service").await);
1231 }
1232
1233 #[tokio::test]
1234 async fn test_remove_service() {
1235 let config = ProxyManagerConfig::default();
1236 let registry = Arc::new(ServiceRegistry::new());
1237 let manager = ProxyManager::new(config, registry, None);
1238
1239 let spec = mock_service_spec_with_endpoints();
1240 manager.add_service("api", &spec).await;
1241 assert_eq!(manager.route_count().await, 2);
1242
1243 manager.remove_service("api").await;
1244 assert_eq!(manager.route_count().await, 0);
1245 assert!(!manager.has_service("api").await);
1246 }
1247
1248 #[tokio::test]
1249 async fn test_backend_management() {
1250 let config = ProxyManagerConfig::default();
1251 let registry = Arc::new(ServiceRegistry::new());
1252 let manager = ProxyManager::new(config, registry.clone(), None);
1253
1254 let spec = mock_service_spec_with_endpoints();
1255 manager.add_service("api", &spec).await;
1256
1257 let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
1259 let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
1260
1261 manager.add_backend("api", addr1).await;
1262 manager.add_backend("api", addr2).await;
1263
1264 let resolved = registry.resolve(None, "/api").await.unwrap();
1266 assert_eq!(resolved.backends.len(), 2);
1267
1268 manager.remove_backend("api", addr1).await;
1270 let resolved = registry.resolve(None, "/api").await.unwrap();
1271 assert_eq!(resolved.backends.len(), 1);
1272 }
1273
1274 #[tokio::test]
1275 async fn test_update_backends_replaces_all() {
1276 let config = ProxyManagerConfig::default();
1277 let registry = Arc::new(ServiceRegistry::new());
1278 let manager = ProxyManager::new(config, registry.clone(), None);
1279
1280 let spec = mock_service_spec_with_endpoints();
1281 manager.add_service("api", &spec).await;
1282
1283 let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
1285 manager.add_backend("api", addr1).await;
1286
1287 let new_backends: Vec<SocketAddr> = vec![
1289 "127.0.0.1:9000".parse().unwrap(),
1290 "127.0.0.1:9001".parse().unwrap(),
1291 "127.0.0.1:9002".parse().unwrap(),
1292 ];
1293 manager.update_backends("api", new_backends).await;
1294
1295 let resolved = registry.resolve(None, "/api").await.unwrap();
1296 assert_eq!(resolved.backends.len(), 3);
1297 }
1298
1299 #[tokio::test]
1300 async fn test_config_builder() {
1301 let config = ProxyManagerConfig::new("0.0.0.0:8080".parse().unwrap())
1302 .with_https("0.0.0.0:8443".parse().unwrap())
1303 .with_http2(false);
1304
1305 assert_eq!(
1306 config.http_addr,
1307 "0.0.0.0:8080".parse::<SocketAddr>().unwrap()
1308 );
1309 assert_eq!(
1310 config.https_addr,
1311 Some("0.0.0.0:8443".parse::<SocketAddr>().unwrap())
1312 );
1313 assert!(!config.http2_enabled);
1314 }
1315
1316 #[tokio::test]
1321 async fn test_ensure_ports_differentiates_public_and_internal() {
1322 let config = ProxyManagerConfig::default();
1323 let registry = Arc::new(ServiceRegistry::new());
1324 let manager = ProxyManager::new(config, registry, None);
1325
1326 let spec = mock_service_spec_with_endpoints();
1327 let result = manager.ensure_ports_for_service(&spec, None).await;
1329 let _ = result;
1332 }
1333
1334 #[tokio::test]
1335 async fn test_ensure_ports_with_overlay_ip() {
1336 let config = ProxyManagerConfig::default();
1337 let registry = Arc::new(ServiceRegistry::new());
1338 let manager = ProxyManager::new(config, registry, None);
1339
1340 let spec = mock_service_spec_with_endpoints();
1341 let overlay_ip: IpAddr = "10.200.0.5".parse().unwrap();
1343 let result = manager
1344 .ensure_ports_for_service(&spec, Some(overlay_ip))
1345 .await;
1346 let _ = result;
1347 }
1348
1349 fn mock_mixed_service_spec() -> ServiceSpec {
1350 use zlayer_spec::*;
1351 serde_yaml::from_str::<DeploymentSpec>(
1352 r"
1353version: v1
1354deployment: test
1355services:
1356 mixed:
1357 rtype: service
1358 image:
1359 name: test:latest
1360 endpoints:
1361 - name: http
1362 protocol: http
1363 port: 8080
1364 path: /api
1365 expose: public
1366 - name: grpc
1367 protocol: tcp
1368 port: 9000
1369 expose: public
1370 - name: game
1371 protocol: udp
1372 port: 27015
1373 expose: public
1374",
1375 )
1376 .unwrap()
1377 .services
1378 .remove("mixed")
1379 .unwrap()
1380 }
1381
1382 #[tokio::test]
1383 async fn test_add_mixed_service_tracks_all_endpoints() {
1384 let config = ProxyManagerConfig::default();
1385 let registry = Arc::new(ServiceRegistry::new());
1386 let manager = ProxyManager::new(config, registry, None);
1387
1388 let spec = mock_mixed_service_spec();
1389 manager.add_service("mixed", &spec).await;
1390
1391 assert_eq!(manager.route_count().await, 1);
1393 assert!(manager.has_service("mixed").await);
1395 }
1396
1397 #[tokio::test]
1398 async fn test_ensure_ports_tcp_with_stream_registry() {
1399 use zlayer_proxy::StreamService;
1400
1401 let stream_registry = Arc::new(StreamRegistry::new());
1402 let config = ProxyManagerConfig::default();
1403 let registry = Arc::new(ServiceRegistry::new());
1404 let mut manager = ProxyManager::new(config, registry, None);
1405 manager.set_stream_registry(stream_registry.clone());
1406
1407 let port = reserve_free_tcp_port();
1411 let spec = mock_service_spec_tcp_only_port(port);
1412
1413 stream_registry.register_tcp(port, StreamService::new("grpc-service".to_string(), vec![]));
1415
1416 let result = manager.ensure_ports_for_service(&spec, None).await;
1418 assert!(result.is_ok());
1419
1420 let tcp_ports = manager.tcp_listeners.read().await;
1422 assert!(tcp_ports.contains(&port));
1423 }
1424
1425 #[tokio::test]
1426 async fn test_ensure_ports_tcp_without_stream_registry() {
1427 let config = ProxyManagerConfig::default();
1428 let registry = Arc::new(ServiceRegistry::new());
1429 let manager = ProxyManager::new(config, registry, None);
1430
1431 let spec = mock_service_spec_tcp_only();
1432
1433 let result = manager.ensure_ports_for_service(&spec, None).await;
1435 assert!(result.is_ok());
1436
1437 let tcp_ports = manager.tcp_listeners.read().await;
1439 assert!(tcp_ports.is_empty());
1440 }
1441
1442 #[tokio::test]
1443 async fn test_stream_registry_setter() {
1444 let stream_registry = Arc::new(StreamRegistry::new());
1445 let config = ProxyManagerConfig::default();
1446 let registry = Arc::new(ServiceRegistry::new());
1447 let mut manager = ProxyManager::new(config, registry, None);
1448
1449 assert!(manager.stream_registry().is_none());
1450 manager.set_stream_registry(stream_registry.clone());
1451 assert!(manager.stream_registry().is_some());
1452 }
1453
1454 fn mock_internal_tcp_spec(port: u16) -> ServiceSpec {
1457 use zlayer_spec::*;
1458 let yaml = format!(
1459 "
1460version: v1
1461deployment: test
1462services:
1463 test:
1464 rtype: service
1465 image:
1466 name: test:latest
1467 scale:
1468 mode: fixed
1469 replicas: 1
1470 endpoints:
1471 - name: tcp
1472 protocol: tcp
1473 port: {port}
1474 expose: internal
1475"
1476 );
1477 serde_yaml::from_str::<DeploymentSpec>(&yaml)
1478 .unwrap()
1479 .services
1480 .remove("test")
1481 .unwrap()
1482 }
1483
1484 #[tokio::test]
1489 async fn test_publish_loopback_round_trips_then_frees_port() {
1490 use tokio::io::{AsyncReadExt, AsyncWriteExt};
1491
1492 let backend = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1494 let backend_addr = backend.local_addr().unwrap();
1495 let backend_ip = backend_addr.ip();
1496 let backend_port = backend_addr.port();
1497 tokio::spawn(async move {
1498 if let Ok((mut sock, _)) = backend.accept().await {
1499 let mut buf = [0u8; 16];
1500 let n = sock.read(&mut buf).await.unwrap_or(0);
1501 let _ = sock.write_all(b"pong:").await;
1503 let _ = sock.write_all(&buf[..n]).await;
1504 let _ = sock.flush().await;
1505 }
1506 });
1507
1508 let config = ProxyManagerConfig::default();
1509 let registry = Arc::new(ServiceRegistry::new());
1510 let manager = ProxyManager::new(config, registry, None);
1511
1512 let publish_port = reserve_free_tcp_port();
1514 let spec = mock_internal_tcp_spec(publish_port);
1515 assert!(
1516 spec.publish_to_node_loopback(),
1517 "single-member internal spec should publish to loopback"
1518 );
1519
1520 manager
1523 .publish_loopback_for_container("test", &spec, backend_ip, Some(backend_port))
1524 .await;
1525
1526 let mut client = tokio::net::TcpStream::connect((Ipv4Addr::LOCALHOST, publish_port))
1528 .await
1529 .expect("connect to published loopback port");
1530 client.write_all(b"ping").await.unwrap();
1531 client.flush().await.unwrap();
1532 let mut reply = Vec::new();
1533 client.read_to_end(&mut reply).await.unwrap();
1534 assert_eq!(&reply, b"pong:ping");
1535 drop(client);
1536
1537 manager
1539 .unpublish_loopback_for_container(&spec, backend_ip, Some(backend_port))
1540 .await;
1541
1542 let mut bound = None;
1545 for _ in 0..50 {
1546 match std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, publish_port)) {
1547 Ok(l) => {
1548 bound = Some(l);
1549 break;
1550 }
1551 Err(_) => tokio::time::sleep(Duration::from_millis(20)).await,
1552 }
1553 }
1554 assert!(
1555 bound.is_some(),
1556 "loopback port {publish_port} should be freed after unpublish"
1557 );
1558 }
1559
1560 #[tokio::test]
1561 async fn test_publish_loopback_skips_public_endpoints() {
1562 let config = ProxyManagerConfig::default();
1566 let registry = Arc::new(ServiceRegistry::new());
1567 let manager = ProxyManager::new(config, registry, None);
1568
1569 let spec = mock_mixed_service_spec();
1570 let backend_ip: IpAddr = "127.0.0.1".parse().unwrap();
1571 manager
1572 .publish_loopback_for_container("mixed", &spec, backend_ip, None)
1573 .await;
1574
1575 assert!(manager.loopback_tcp.read().await.is_empty());
1577 assert!(manager.loopback_udp.read().await.is_empty());
1578 }
1579
1580 #[tokio::test]
1581 async fn test_registry_accessor() {
1582 let config = ProxyManagerConfig::default();
1583 let registry = Arc::new(ServiceRegistry::new());
1584 let manager = ProxyManager::new(config, registry.clone(), None);
1585
1586 assert_eq!(Arc::as_ptr(&manager.registry()), Arc::as_ptr(®istry));
1588 }
1589}