1use crate::error::Result;
11use std::collections::{HashMap, HashSet};
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::RwLock;
17use tracing::{debug, info, warn};
18use zlayer_proxy::{
19 load_existing_certs_into_resolver, CertManager, LbStrategy, LoadBalancer, NetworkPolicyChecker,
20 ProxyConfig, ProxyServer, RouteEntry, ServiceRegistry, SniCertResolver, StreamRegistry,
21 TcpStreamService, UdpStreamService,
22};
23use zlayer_spec::{ExposeType, Protocol, ServiceSpec};
24
/// Listener configuration for a [`ProxyManager`].
///
/// Build one with [`ProxyManagerConfig::new`] (or [`Default`]) and refine it
/// with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyManagerConfig {
    /// Address of the primary plain-HTTP listener.
    pub http_addr: SocketAddr,
    /// Address of the HTTPS listener; `None` means HTTPS is not configured.
    pub https_addr: Option<SocketAddr>,
    /// Whether HTTP/2 is enabled on the proxy servers the manager starts.
    pub http2_enabled: bool,
}

impl Default for ProxyManagerConfig {
    /// HTTP on `0.0.0.0:80`, HTTPS not configured, HTTP/2 enabled.
    fn default() -> Self {
        // Built directly instead of parsing a string literal: no runtime parse,
        // no `unwrap`, and `new` stays the single source of the other defaults.
        Self::new(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 80))
    }
}

impl ProxyManagerConfig {
    /// Creates a config listening for HTTP on `http_addr`, with HTTPS not
    /// configured and HTTP/2 enabled.
    #[must_use]
    pub fn new(http_addr: SocketAddr) -> Self {
        Self {
            http_addr,
            https_addr: None,
            http2_enabled: true,
        }
    }

    /// Sets the HTTPS listener address.
    #[must_use]
    pub fn with_https(mut self, addr: SocketAddr) -> Self {
        self.https_addr = Some(addr);
        self
    }

    /// Enables or disables HTTP/2.
    #[must_use]
    pub fn with_http2(mut self, enabled: bool) -> Self {
        self.http2_enabled = enabled;
        self
    }
}
71
/// Bookkeeping recorded for one service by `ProxyManager::add_service`.
///
/// `ProxyManager::remove_service` reads it back to know which routes,
/// stream ports and HTTP listeners belong to the service.
#[derive(Debug, Clone)]
struct ServiceTracking {
    /// Names of all endpoints the service declared, in spec order.
    // Recorded but not read anywhere in this module, hence the allow.
    #[allow(dead_code)]
    endpoint_names: Vec<String>,
    /// Ports of the service's TCP stream endpoints.
    tcp_ports: Vec<u16>,
    /// Ports of the service's UDP stream endpoints.
    udp_ports: Vec<u16>,
    /// Ports of the service's HTTP, HTTPS and WebSocket endpoints.
    http_ports: Vec<u16>,
}
85
86pub struct ProxyManager {
96 config: ProxyManagerConfig,
98 registry: Arc<ServiceRegistry>,
100 load_balancer: Arc<LoadBalancer>,
102 servers: RwLock<HashMap<u16, Arc<ProxyServer>>>,
104 services: RwLock<HashMap<String, ServiceTracking>>,
106 stream_registry: Option<Arc<StreamRegistry>>,
108 cert_manager: Option<Arc<CertManager>>,
110 tcp_listeners: RwLock<HashSet<u16>>,
112 udp_listeners: RwLock<HashSet<u16>>,
114 active_connections: Arc<AtomicU64>,
116 network_policy_checker: Option<NetworkPolicyChecker>,
118}
119
120impl ProxyManager {
121 pub fn new(
124 config: ProxyManagerConfig,
125 registry: Arc<ServiceRegistry>,
126 cert_manager: Option<Arc<CertManager>>,
127 ) -> Self {
128 let load_balancer = Arc::new(LoadBalancer::new());
129
130 Self {
131 config,
132 registry,
133 load_balancer,
134 servers: RwLock::new(HashMap::new()),
135 services: RwLock::new(HashMap::new()),
136 stream_registry: None,
137 cert_manager,
138 tcp_listeners: RwLock::new(HashSet::new()),
139 udp_listeners: RwLock::new(HashSet::new()),
140 active_connections: Arc::new(AtomicU64::new(0)),
141 network_policy_checker: None,
142 }
143 }
144
145 pub fn registry(&self) -> Arc<ServiceRegistry> {
147 self.registry.clone()
148 }
149
150 pub fn load_balancer(&self) -> Arc<LoadBalancer> {
152 self.load_balancer.clone()
153 }
154
155 pub fn active_connections(&self) -> u64 {
157 self.active_connections.load(Ordering::Relaxed)
158 }
159
160 pub fn cert_manager(&self) -> Option<&Arc<CertManager>> {
162 self.cert_manager.as_ref()
163 }
164
165 pub fn set_stream_registry(&mut self, registry: Arc<StreamRegistry>) {
167 self.stream_registry = Some(registry);
168 }
169
170 #[must_use]
172 pub fn with_stream_registry(mut self, registry: Arc<StreamRegistry>) -> Self {
173 self.stream_registry = Some(registry);
174 self
175 }
176
177 pub fn stream_registry(&self) -> Option<&Arc<StreamRegistry>> {
179 self.stream_registry.as_ref()
180 }
181
182 pub fn set_network_policy_checker(&mut self, checker: NetworkPolicyChecker) {
184 self.network_policy_checker = Some(checker);
185 }
186
187 #[must_use]
189 pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
190 self.network_policy_checker = Some(checker);
191 self
192 }
193
194 pub async fn listen_on(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
202 let mut servers = self.servers.write().await;
203
204 if servers.contains_key(&port) {
205 debug!(port = port, "Already listening on port");
206 return Ok(());
207 }
208
209 let addr = SocketAddr::new(bind_ip, port);
210 let mut proxy_config = ProxyConfig::default();
211 proxy_config.server.http_addr = addr;
212 proxy_config.server.http2_enabled = self.config.http2_enabled;
213
214 let mut server = ProxyServer::with_registry(
215 proxy_config,
216 self.registry.clone(),
217 self.load_balancer.clone(),
218 );
219 if let Some(ref checker) = self.network_policy_checker {
220 server = server.with_network_policy_checker(checker.clone());
221 }
222 let server = Arc::new(server);
223
224 info!(port = port, bind = %addr, "Proxy listening on port");
225
226 let server_clone = server.clone();
227 tokio::spawn(async move {
228 if let Err(e) = server_clone.run().await {
229 tracing::error!(port = port, error = %e, "Proxy server error on port");
230 }
231 });
232
233 servers.insert(port, server);
234 Ok(())
235 }
236
237 pub async fn listen_on_tls(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
245 let mut servers = self.servers.write().await;
246
247 if servers.contains_key(&port) {
248 debug!(port = port, "Already listening on port (TLS)");
249 return Ok(());
250 }
251
252 let Some(cert_manager) = &self.cert_manager else {
253 warn!(
254 port = port,
255 "Cannot start TLS listener: no CertManager configured"
256 );
257 return Ok(());
258 };
259
260 let sni_resolver = Arc::new(SniCertResolver::new());
262
263 let _ = load_existing_certs_into_resolver(cert_manager, &sni_resolver).await;
265
266 let addr = SocketAddr::new(bind_ip, port);
267 let mut proxy_config = ProxyConfig::default();
268 proxy_config.server.https_addr = addr;
269
270 let mut server = ProxyServer::with_tls_resolver(
271 proxy_config,
272 self.registry.clone(),
273 self.load_balancer.clone(),
274 sni_resolver,
275 )
276 .with_cert_manager(Arc::clone(cert_manager));
277 if let Some(ref checker) = self.network_policy_checker {
278 server = server.with_network_policy_checker(checker.clone());
279 }
280 let server = Arc::new(server);
281
282 info!(port = port, bind = %addr, "HTTPS proxy listening on port");
283
284 let server_clone = server.clone();
285 tokio::spawn(async move {
286 if let Err(e) = server_clone.run_https().await {
287 tracing::error!(port = port, error = %e, "HTTPS proxy server error");
288 }
289 });
290
291 servers.insert(port, server);
292 Ok(())
293 }
294
295 pub async fn stop(&self) {
300 let mut servers = self.servers.write().await;
301 for (port, server) in servers.drain() {
302 info!(port = port, "Stopping proxy on port");
303 server.shutdown();
304 }
305
306 let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
308 while self.active_connections.load(Ordering::Relaxed) > 0 {
309 if tokio::time::Instant::now() >= deadline {
310 let remaining = self.active_connections.load(Ordering::Relaxed);
311 warn!(
312 remaining = remaining,
313 "Drain timeout reached, forcing shutdown"
314 );
315 break;
316 }
317 tokio::time::sleep(Duration::from_millis(100)).await;
318 }
319
320 info!("All proxy servers stopped");
321 }
322
323 pub async fn unbind(&self, port: u16) {
325 let mut servers = self.servers.write().await;
326 if let Some(server) = servers.remove(&port) {
327 info!(port = port, "Unbinding proxy from port");
328 server.shutdown();
329 }
330 }
331
332 pub async fn ensure_ports_for_service(
348 &self,
349 spec: &ServiceSpec,
350 overlay_ip: Option<IpAddr>,
351 ) -> Result<()> {
352 for endpoint in &spec.endpoints {
353 let bind_ip = match endpoint.expose {
354 ExposeType::Public => IpAddr::V4(Ipv4Addr::UNSPECIFIED), ExposeType::Internal => {
356 let ip = overlay_ip.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
358 if overlay_ip.is_none() {
359 warn!(
360 endpoint = %endpoint.name,
361 port = endpoint.port,
362 "No overlay IP available for internal endpoint; binding to 127.0.0.1"
363 );
364 }
365 ip
366 }
367 };
368
369 match endpoint.protocol {
370 Protocol::Https => {
371 self.listen_on_tls(endpoint.port, bind_ip).await?;
373 }
374 Protocol::Http | Protocol::Websocket => {
375 self.listen_on(endpoint.port, bind_ip).await?;
377 }
378 Protocol::Tcp => {
379 self.ensure_tcp_listener(endpoint.port, bind_ip).await;
381 }
382 Protocol::Udp => {
383 self.ensure_udp_listener(endpoint.port, bind_ip).await;
385 }
386 }
387 }
388 Ok(())
389 }
390
391 async fn ensure_tcp_listener(&self, port: u16, bind_ip: IpAddr) {
396 {
398 let listeners = self.tcp_listeners.read().await;
399 if listeners.contains(&port) {
400 debug!(port = port, "TCP stream listener already active");
401 return;
402 }
403 }
404
405 let registry = if let Some(r) = &self.stream_registry {
406 Arc::clone(r)
407 } else {
408 warn!(
409 port = port,
410 "Cannot start TCP listener: StreamRegistry not configured"
411 );
412 return;
413 };
414
415 let addr = SocketAddr::new(bind_ip, port);
416 let listener = match tokio::net::TcpListener::bind(addr).await {
417 Ok(l) => l,
418 Err(e) => {
419 warn!(
420 port = port,
421 bind = %addr,
422 error = %e,
423 "Failed to bind TCP stream listener, continuing"
424 );
425 return;
426 }
427 };
428
429 {
431 let mut listeners = self.tcp_listeners.write().await;
432 listeners.insert(port);
433 }
434
435 let tcp_service = Arc::new(TcpStreamService::new(registry, port));
436 tokio::spawn(async move {
437 tcp_service.serve(listener).await;
438 });
439
440 info!(port = port, bind = %addr, "TCP stream proxy listening");
441 }
442
443 async fn ensure_udp_listener(&self, port: u16, bind_ip: IpAddr) {
448 {
450 let listeners = self.udp_listeners.read().await;
451 if listeners.contains(&port) {
452 debug!(port = port, "UDP stream listener already active");
453 return;
454 }
455 }
456
457 let registry = if let Some(r) = &self.stream_registry {
458 Arc::clone(r)
459 } else {
460 warn!(
461 port = port,
462 "Cannot start UDP listener: StreamRegistry not configured"
463 );
464 return;
465 };
466
467 let addr = SocketAddr::new(bind_ip, port);
468 let socket = match tokio::net::UdpSocket::bind(addr).await {
469 Ok(s) => s,
470 Err(e) => {
471 warn!(
472 port = port,
473 bind = %addr,
474 error = %e,
475 "Failed to bind UDP stream listener, continuing"
476 );
477 return;
478 }
479 };
480
481 {
483 let mut listeners = self.udp_listeners.write().await;
484 listeners.insert(port);
485 }
486
487 let udp_service = Arc::new(UdpStreamService::new(registry, port, None));
488 tokio::spawn(async move {
489 if let Err(e) = udp_service.serve(socket).await {
490 tracing::error!(
491 port = port,
492 error = %e,
493 "UDP stream proxy service failed"
494 );
495 }
496 });
497
498 info!(port = port, bind = %addr, "UDP stream proxy listening");
499 }
500
501 pub async fn add_service(&self, name: &str, spec: &ServiceSpec) {
508 let mut services = self.services.write().await;
509
510 let mut endpoint_names = Vec::new();
512 let mut tcp_ports = Vec::new();
513 let mut udp_ports = Vec::new();
514 let mut http_ports = Vec::new();
515
516 for endpoint in &spec.endpoints {
517 match endpoint.protocol {
518 Protocol::Http | Protocol::Https | Protocol::Websocket => {
519 let entry = RouteEntry::from_endpoint(name, endpoint);
521 self.registry.register(entry).await;
522 http_ports.push(endpoint.port);
523
524 info!(
525 service = name,
526 endpoint = %endpoint.name,
527 protocol = ?endpoint.protocol,
528 path = ?endpoint.path,
529 expose = ?endpoint.expose,
530 "Added HTTP proxy route for service"
531 );
532 }
533 Protocol::Tcp => {
534 tcp_ports.push(endpoint.port);
535 info!(
536 service = name,
537 endpoint = %endpoint.name,
538 protocol = ?endpoint.protocol,
539 port = endpoint.port,
540 expose = ?endpoint.expose,
541 "Tracking TCP stream endpoint for service"
542 );
543 }
544 Protocol::Udp => {
545 udp_ports.push(endpoint.port);
546 info!(
547 service = name,
548 endpoint = %endpoint.name,
549 protocol = ?endpoint.protocol,
550 port = endpoint.port,
551 expose = ?endpoint.expose,
552 "Tracking UDP stream endpoint for service"
553 );
554 }
555 }
556
557 endpoint_names.push(endpoint.name.clone());
558 }
559
560 self.load_balancer
562 .register(name, vec![], LbStrategy::RoundRobin);
563
564 services.insert(
565 name.to_string(),
566 ServiceTracking {
567 endpoint_names,
568 tcp_ports,
569 udp_ports,
570 http_ports,
571 },
572 );
573 }
574
575 pub async fn remove_service(&self, name: &str) {
585 let mut services = self.services.write().await;
586
587 if let Some(tracking) = services.remove(name) {
588 self.registry.unregister_service(name).await;
590
591 self.load_balancer.unregister(name);
593
594 if !tracking.tcp_ports.is_empty() {
596 let mut tcp_set = self.tcp_listeners.write().await;
597 for port in &tracking.tcp_ports {
598 if let Some(registry) = &self.stream_registry {
599 let _ = registry.unregister_tcp(*port);
600 }
601 tcp_set.remove(port);
602 debug!(service = name, port = port, "Removed TCP listener tracking");
603 }
604 }
605
606 if !tracking.udp_ports.is_empty() {
608 let mut udp_set = self.udp_listeners.write().await;
609 for port in &tracking.udp_ports {
610 if let Some(registry) = &self.stream_registry {
611 let _ = registry.unregister_udp(*port);
612 }
613 udp_set.remove(port);
614 debug!(service = name, port = port, "Removed UDP listener tracking");
615 }
616 }
617
618 if !tracking.http_ports.is_empty() {
621 let ports_still_in_use: HashSet<u16> = services
622 .values()
623 .flat_map(|t| t.http_ports.iter().copied())
624 .collect();
625
626 let mut servers = self.servers.write().await;
627 for port in &tracking.http_ports {
628 if !ports_still_in_use.contains(port) {
629 if let Some(server) = servers.remove(port) {
630 server.shutdown();
631 info!(
632 service = name,
633 port = port,
634 "Shut down HTTP proxy server (no remaining services on port)"
635 );
636 }
637 }
638 }
639 }
640
641 info!(service = name, "Removed all proxy resources for service");
642 }
643 }
644
645 pub async fn add_backend(&self, service: &str, addr: SocketAddr) {
647 self.registry.add_backend(service, addr).await;
648 self.load_balancer.add_backend(service, addr);
649 info!(service = service, backend = %addr, "Registered backend with proxy");
650 }
651
652 pub async fn remove_backend(&self, service: &str, addr: SocketAddr) {
654 self.registry.remove_backend(service, addr).await;
655 self.load_balancer.remove_backend(service, &addr);
656 debug!(service = service, backend = %addr, "Removed backend from service");
657 }
658
659 #[allow(clippy::unused_async)]
664 pub async fn update_backend_health(&self, service: &str, addr: SocketAddr, healthy: bool) {
665 self.load_balancer.mark_health(service, &addr, healthy);
666 debug!(
667 service = service,
668 backend = %addr,
669 healthy = healthy,
670 "Updated backend health in load balancer"
671 );
672 }
673
674 pub async fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
679 self.registry.update_backends(service, addrs.clone()).await;
680 self.load_balancer.update_backends(service, addrs);
681 debug!(service = service, "Updated backends for service");
682 }
683
684 pub async fn route_count(&self) -> usize {
686 self.registry.route_count().await
687 }
688
689 pub async fn list_services(&self) -> Vec<String> {
691 self.services.read().await.keys().cloned().collect()
692 }
693
694 pub async fn has_service(&self, name: &str) -> bool {
696 self.services.read().await.contains_key(name)
697 }
698}
699
#[cfg(test)]
mod tests {
    use super::*;

    /// Service with a public HTTP endpoint (8080, `/api`) and an internal
    /// WebSocket endpoint (8081, `/ws`).
    fn mock_service_spec_with_endpoints() -> ServiceSpec {
        use zlayer_spec::*;
        serde_yaml::from_str::<DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: http
        protocol: http
        port: 8080
        path: /api
        expose: public
      - name: websocket
        protocol: websocket
        port: 8081
        path: /ws
        expose: internal
",
        )
        .unwrap()
        .services
        .remove("test")
        .unwrap()
    }

    /// Service with a single TCP endpoint on port 9000.
    fn mock_service_spec_tcp_only() -> ServiceSpec {
        mock_service_spec_tcp_only_port(9000)
    }

    /// Service with a single TCP endpoint on `port`.
    fn mock_service_spec_tcp_only_port(port: u16) -> ServiceSpec {
        use zlayer_spec::*;
        let yaml = format!(
            "
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: grpc
        protocol: tcp
        port: {port}
"
        );
        serde_yaml::from_str::<DeploymentSpec>(&yaml)
            .unwrap()
            .services
            .remove("test")
            .unwrap()
    }

    /// Asks the OS for a currently free TCP port on 127.0.0.1.
    ///
    /// The probe listener is dropped on return, so another process could in
    /// principle grab the port before the test binds it.
    fn reserve_free_tcp_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral test port");
        listener.local_addr().unwrap().port()
    }

    #[tokio::test]
    async fn test_proxy_manager_new() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        assert_eq!(manager.route_count().await, 0);
        assert!(manager.list_services().await.is_empty());
    }

    #[tokio::test]
    async fn test_add_service_with_http_endpoints() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;

        // One route per HTTP-family endpoint.
        assert_eq!(manager.route_count().await, 2);
        assert!(manager.has_service("api").await);
    }

    #[tokio::test]
    async fn test_tcp_endpoints_tracked_not_routed() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_service_spec_tcp_only();
        manager.add_service("grpc-service", &spec).await;

        // TCP endpoints create no HTTP routes, but the service is still tracked.
        assert_eq!(manager.route_count().await, 0);
        assert!(manager.has_service("grpc-service").await);

        let services = manager.services.read().await;
        let tracking = services
            .get("grpc-service")
            .expect("grpc-service should be tracked");
        assert_eq!(tracking.tcp_ports, vec![9000]);
        assert!(tracking.http_ports.is_empty());
        assert!(tracking.udp_ports.is_empty());
    }

    #[tokio::test]
    async fn test_remove_service() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;
        assert_eq!(manager.route_count().await, 2);

        manager.remove_service("api").await;
        assert_eq!(manager.route_count().await, 0);
        assert!(!manager.has_service("api").await);
    }

    #[tokio::test]
    async fn test_backend_management() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry.clone(), None);

        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;

        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();

        manager.add_backend("api", addr1).await;
        manager.add_backend("api", addr2).await;

        let resolved = registry.resolve(None, "/api").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);

        manager.remove_backend("api", addr1).await;
        let resolved = registry.resolve(None, "/api").await.unwrap();
        assert_eq!(resolved.backends.len(), 1);
    }

    #[tokio::test]
    async fn test_update_backends_replaces_all() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry.clone(), None);

        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;

        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        manager.add_backend("api", addr1).await;

        let new_backends: Vec<SocketAddr> = vec![
            "127.0.0.1:9000".parse().unwrap(),
            "127.0.0.1:9001".parse().unwrap(),
            "127.0.0.1:9002".parse().unwrap(),
        ];
        manager.update_backends("api", new_backends).await;

        let resolved = registry.resolve(None, "/api").await.unwrap();
        assert_eq!(resolved.backends.len(), 3);
    }

    #[tokio::test]
    async fn test_config_builder() {
        let config = ProxyManagerConfig::new("0.0.0.0:8080".parse().unwrap())
            .with_https("0.0.0.0:8443".parse().unwrap())
            .with_http2(false);

        assert_eq!(
            config.http_addr,
            "0.0.0.0:8080".parse::<SocketAddr>().unwrap()
        );
        assert_eq!(
            config.https_addr,
            Some("0.0.0.0:8443".parse::<SocketAddr>().unwrap())
        );
        assert!(!config.http2_enabled);
    }

    #[tokio::test]
    async fn test_ensure_ports_differentiates_public_and_internal() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_service_spec_with_endpoints();
        let result = manager.ensure_ports_for_service(&spec, None).await;

        // HTTP servers bind inside spawned tasks, so setup itself succeeds and
        // both ports are registered even if the background bind fails.
        assert!(result.is_ok());
        let servers = manager.servers.read().await;
        assert!(servers.contains_key(&8080));
        assert!(servers.contains_key(&8081));
    }

    #[tokio::test]
    async fn test_ensure_ports_with_overlay_ip() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_service_spec_with_endpoints();
        // Not assigned on the test host; the internal endpoint's bind fails in
        // the background task, which must not surface as an error here.
        let overlay_ip: IpAddr = "10.200.0.5".parse().unwrap();
        let result = manager
            .ensure_ports_for_service(&spec, Some(overlay_ip))
            .await;

        assert!(result.is_ok());
        let servers = manager.servers.read().await;
        assert!(servers.contains_key(&8080));
        assert!(servers.contains_key(&8081));
    }

    /// Service mixing one HTTP (8080), one TCP (9000) and one UDP (27015) endpoint.
    fn mock_mixed_service_spec() -> ServiceSpec {
        use zlayer_spec::*;
        serde_yaml::from_str::<DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  mixed:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: http
        protocol: http
        port: 8080
        path: /api
        expose: public
      - name: grpc
        protocol: tcp
        port: 9000
        expose: public
      - name: game
        protocol: udp
        port: 27015
        expose: public
",
        )
        .unwrap()
        .services
        .remove("mixed")
        .unwrap()
    }

    #[tokio::test]
    async fn test_add_mixed_service_tracks_all_endpoints() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_mixed_service_spec();
        manager.add_service("mixed", &spec).await;

        // Only the HTTP endpoint becomes a route.
        assert_eq!(manager.route_count().await, 1);
        assert!(manager.has_service("mixed").await);

        let services = manager.services.read().await;
        let tracking = services
            .get("mixed")
            .expect("mixed service should be tracked");
        assert_eq!(tracking.http_ports, vec![8080]);
        assert_eq!(tracking.tcp_ports, vec![9000]);
        assert_eq!(tracking.udp_ports, vec![27015]);
        assert_eq!(tracking.endpoint_names, vec!["http", "grpc", "game"]);
    }

    #[tokio::test]
    async fn test_ensure_ports_tcp_with_stream_registry() {
        use zlayer_proxy::StreamService;

        let stream_registry = Arc::new(StreamRegistry::new());
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let mut manager = ProxyManager::new(config, registry, None);
        manager.set_stream_registry(stream_registry.clone());

        // Use an OS-chosen free port so parallel tests don't collide on bind.
        let port = reserve_free_tcp_port();
        let spec = mock_service_spec_tcp_only_port(port);

        stream_registry.register_tcp(port, StreamService::new("grpc-service".to_string(), vec![]));

        let result = manager.ensure_ports_for_service(&spec, None).await;
        assert!(result.is_ok());

        // A successful bind records the port as having an active TCP listener.
        let tcp_ports = manager.tcp_listeners.read().await;
        assert!(tcp_ports.contains(&port));
    }

    #[tokio::test]
    async fn test_ensure_ports_tcp_without_stream_registry() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);

        let spec = mock_service_spec_tcp_only();

        // A missing stream registry is logged, not treated as an error.
        let result = manager.ensure_ports_for_service(&spec, None).await;
        assert!(result.is_ok());

        // ...and no listener is recorded.
        let tcp_ports = manager.tcp_listeners.read().await;
        assert!(tcp_ports.is_empty());
    }

    #[tokio::test]
    async fn test_stream_registry_setter() {
        let stream_registry = Arc::new(StreamRegistry::new());
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let mut manager = ProxyManager::new(config, registry, None);

        assert!(manager.stream_registry().is_none());
        manager.set_stream_registry(stream_registry);
        assert!(manager.stream_registry().is_some());
    }

    #[tokio::test]
    async fn test_registry_accessor() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry.clone(), None);

        // The accessor must hand out the same shared registry, not a copy.
        assert_eq!(Arc::as_ptr(&manager.registry()), Arc::as_ptr(&registry));
    }
}