1use async_trait::async_trait;
45use parking_lot::RwLock;
46use pingora::prelude::*;
47use pingora_load_balancing::discovery::{ServiceDiscovery, Static as StaticDiscovery};
48use pingora_load_balancing::Backend;
49use std::collections::{BTreeSet, HashMap};
50use std::net::ToSocketAddrs;
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53use tracing::{debug, error, info, trace, warn};
54
/// Describes how the set of backends for one upstream is obtained.
#[derive(Debug, Clone)]
pub enum DiscoveryConfig {
    /// A fixed list of backends.
    Static {
        /// Backend addresses in `host:port` form.
        backends: Vec<String>,
    },
    /// Periodic resolution of a single hostname through DNS.
    Dns {
        /// Hostname to resolve.
        hostname: String,
        /// Port paired with every resolved address.
        port: u16,
        /// How long a resolution result is reused before resolving again.
        refresh_interval: Duration,
    },
    /// DNS SRV discovery.
    ///
    /// NOTE: `DiscoveryManager::register` currently falls back to a plain address lookup of the
    /// name with its leading `_`-prefixed labels removed.
    DnsSrv {
        /// SRV-style service name.
        service: String,
        /// How long a result is reused before resolving again.
        refresh_interval: Duration,
    },
    /// Discovery through Consul's health endpoint.
    Consul {
        /// Base URL of the Consul HTTP API, such as `http://localhost:8500`.
        address: String,
        /// Consul service name.
        service: String,
        /// Optional datacenter filter.
        datacenter: Option<String>,
        /// Request only instances whose health checks pass.
        only_passing: bool,
        /// How long a query result is reused before querying again.
        refresh_interval: Duration,
        /// Optional tag filter.
        tag: Option<String>,
    },
    /// Kubernetes endpoint discovery.
    Kubernetes {
        /// Namespace of the service.
        namespace: String,
        /// Service whose endpoints are discovered.
        service: String,
        /// Optional named port of the service.
        port_name: Option<String>,
        /// How long a result is reused before refreshing.
        refresh_interval: Duration,
        /// Optional kubeconfig; only in-cluster configuration is supported at the moment.
        kubeconfig: Option<String>,
    },
    /// File-based discovery.
    ///
    /// NOTE: not implemented yet; registration yields an empty backend set.
    File {
        /// File to read backends from.
        path: String,
        /// How often the file would be checked for changes.
        watch_interval: Duration,
    },
}

impl Default for DiscoveryConfig {
    /// A single static backend at `127.0.0.1:8080`.
    fn default() -> Self {
        let local_backend = String::from("127.0.0.1:8080");
        DiscoveryConfig::Static {
            backends: vec![local_backend],
        }
    }
}
123
124pub struct DnsDiscovery {
128 hostname: String,
129 port: u16,
130 refresh_interval: Duration,
131 cached_backends: RwLock<BTreeSet<Backend>>,
133 last_resolution: RwLock<Instant>,
135}
136
137impl DnsDiscovery {
138 pub fn new(hostname: String, port: u16, refresh_interval: Duration) -> Self {
140 Self {
141 hostname,
142 port,
143 refresh_interval,
144 cached_backends: RwLock::new(BTreeSet::new()),
145 last_resolution: RwLock::new(Instant::now() - refresh_interval),
146 }
147 }
148
149 fn resolve(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
151 let address = format!("{}:{}", self.hostname, self.port);
152
153 trace!(
154 hostname = %self.hostname,
155 port = self.port,
156 "Resolving DNS for service discovery"
157 );
158
159 match address.to_socket_addrs() {
160 Ok(addrs) => {
161 let backends: BTreeSet<Backend> = addrs
162 .map(|addr| Backend {
163 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
164 weight: 1,
165 ext: http::Extensions::new(),
166 })
167 .collect();
168
169 debug!(
170 hostname = %self.hostname,
171 backend_count = backends.len(),
172 "DNS resolution successful"
173 );
174
175 Ok(backends)
176 }
177 Err(e) => {
178 error!(
179 hostname = %self.hostname,
180 error = %e,
181 "DNS resolution failed"
182 );
183 Err(Error::explain(
184 ErrorType::ConnectNoRoute,
185 format!("DNS resolution failed for {}: {}", self.hostname, e),
186 ))
187 }
188 }
189 }
190
191 fn needs_refresh(&self) -> bool {
193 let last = *self.last_resolution.read();
194 last.elapsed() >= self.refresh_interval
195 }
196}
197
198#[async_trait]
199impl ServiceDiscovery for DnsDiscovery {
200 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
201 if self.needs_refresh() {
203 match self.resolve() {
204 Ok(backends) => {
205 *self.cached_backends.write() = backends;
206 *self.last_resolution.write() = Instant::now();
207 }
208 Err(e) => {
209 let cached = self.cached_backends.read().clone();
211 if !cached.is_empty() {
212 warn!(
213 hostname = %self.hostname,
214 error = %e,
215 cached_count = cached.len(),
216 "DNS resolution failed, using cached backends"
217 );
218 return Ok((cached, HashMap::new()));
219 }
220 return Err(e);
221 }
222 }
223 }
224
225 let backends = self.cached_backends.read().clone();
226 Ok((backends, HashMap::new()))
227 }
228}
229
230pub struct ConsulDiscovery {
238 address: String,
240 service: String,
242 datacenter: Option<String>,
244 only_passing: bool,
246 refresh_interval: Duration,
248 tag: Option<String>,
250 cached_backends: RwLock<BTreeSet<Backend>>,
252 last_resolution: RwLock<Instant>,
254}
255
256impl ConsulDiscovery {
257 pub fn new(
259 address: String,
260 service: String,
261 datacenter: Option<String>,
262 only_passing: bool,
263 refresh_interval: Duration,
264 tag: Option<String>,
265 ) -> Self {
266 Self {
267 address,
268 service,
269 datacenter,
270 only_passing,
271 refresh_interval,
272 tag,
273 cached_backends: RwLock::new(BTreeSet::new()),
274 last_resolution: RwLock::new(Instant::now() - refresh_interval),
275 }
276 }
277
278 fn build_url(&self) -> String {
280 let mut url = format!(
281 "{}/v1/health/service/{}",
282 self.address.trim_end_matches('/'),
283 self.service
284 );
285
286 let mut params = Vec::new();
287 if self.only_passing {
288 params.push("passing=true".to_string());
289 }
290 if let Some(dc) = &self.datacenter {
291 params.push(format!("dc={}", dc));
292 }
293 if let Some(tag) = &self.tag {
294 params.push(format!("tag={}", tag));
295 }
296
297 if !params.is_empty() {
298 url.push('?');
299 url.push_str(¶ms.join("&"));
300 }
301
302 url
303 }
304
305 fn needs_refresh(&self) -> bool {
307 let last = *self.last_resolution.read();
308 last.elapsed() >= self.refresh_interval
309 }
310}
311
312#[async_trait]
313impl ServiceDiscovery for ConsulDiscovery {
314 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
315 if !self.needs_refresh() {
316 let backends = self.cached_backends.read().clone();
317 return Ok((backends, HashMap::new()));
318 }
319
320 let url = self.build_url();
321 trace!(
322 service = %self.service,
323 url = %url,
324 "Querying Consul for service discovery"
325 );
326
327 let result = tokio::task::spawn_blocking({
330 let url = url.clone();
331 let service = self.service.clone();
332 move || -> Result<BTreeSet<Backend>, Box<Error>> {
333 let url_parsed = url
336 .trim_start_matches("http://")
337 .trim_start_matches("https://");
338 let (host_port, path) = url_parsed.split_once('/').unwrap_or((url_parsed, ""));
339
340 let socket_addr = host_port
341 .to_socket_addrs()
342 .map_err(|e| {
343 Error::explain(
344 ErrorType::ConnectNoRoute,
345 format!("Failed to resolve Consul address: {}", e),
346 )
347 })?
348 .next()
349 .ok_or_else(|| {
350 Error::explain(
351 ErrorType::ConnectNoRoute,
352 "Failed to resolve Consul address",
353 )
354 })?;
355
356 let stream = match std::net::TcpStream::connect_timeout(
357 &socket_addr,
358 Duration::from_secs(5),
359 ) {
360 Ok(s) => s,
361 Err(e) => {
362 return Err(Error::explain(
363 ErrorType::ConnectTimedout,
364 format!("Failed to connect to Consul: {}", e),
365 ));
366 }
367 };
368
369 stream
370 .set_read_timeout(Some(Duration::from_secs(10)))
371 .map_err(|e| {
372 Error::explain(
373 ErrorType::InternalError,
374 format!("Failed to set read timeout: {}", e),
375 )
376 })?;
377 stream
378 .set_write_timeout(Some(Duration::from_secs(5)))
379 .map_err(|e| {
380 Error::explain(
381 ErrorType::InternalError,
382 format!("Failed to set write timeout: {}", e),
383 )
384 })?;
385
386 use std::io::{Read, Write};
387 let request = format!(
388 "GET /{} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
389 path, host_port
390 );
391
392 let mut stream = stream;
393 stream.write_all(request.as_bytes()).map_err(|e| {
394 Error::explain(
395 ErrorType::WriteError,
396 format!("Failed to send request: {}", e),
397 )
398 })?;
399
400 let mut response = String::new();
401 stream.read_to_string(&mut response).map_err(|e| {
402 Error::explain(
403 ErrorType::ReadError,
404 format!("Failed to read response: {}", e),
405 )
406 })?;
407
408 let body = response.split("\r\n\r\n").nth(1).unwrap_or("");
410
411 let backends = parse_consul_response(body, &service)?;
414
415 Ok(backends)
416 }
417 })
418 .await
419 .map_err(|e| Error::explain(ErrorType::InternalError, format!("Task failed: {}", e)))?;
420
421 match result {
422 Ok(backends) => {
423 info!(
424 service = %self.service,
425 backend_count = backends.len(),
426 "Consul discovery successful"
427 );
428 *self.cached_backends.write() = backends.clone();
429 *self.last_resolution.write() = Instant::now();
430 Ok((backends, HashMap::new()))
431 }
432 Err(e) => {
433 let cached = self.cached_backends.read().clone();
434 if !cached.is_empty() {
435 warn!(
436 service = %self.service,
437 error = %e,
438 cached_count = cached.len(),
439 "Consul query failed, using cached backends"
440 );
441 return Ok((cached, HashMap::new()));
442 }
443 Err(e)
444 }
445 }
446 }
447}
448
449fn parse_consul_response(body: &str, service_name: &str) -> Result<BTreeSet<Backend>, Box<Error>> {
451 let mut backends = BTreeSet::new();
454
455 let entries: Vec<&str> = body.split(r#""Service":"#).skip(1).collect();
457
458 for entry in entries {
459 let port = entry
461 .split(r#""Port":"#)
462 .nth(1)
463 .and_then(|s| s.split(|c: char| !c.is_ascii_digit()).next())
464 .and_then(|s| s.parse::<u16>().ok());
465
466 let service_addr = entry
468 .split(r#""Address":""#)
469 .nth(1)
470 .and_then(|s| s.split('"').next())
471 .filter(|s| !s.is_empty());
472
473 let node_addr = body
475 .split(r#""Node":"#)
476 .nth(1)
477 .and_then(|s| s.split(r#""Address":""#).nth(1))
478 .and_then(|s| s.split('"').next());
479
480 let address = service_addr.or(node_addr);
481
482 if let (Some(addr), Some(port)) = (address, port) {
483 let full_addr = format!("{}:{}", addr, port);
484 if let Ok(mut addrs) = full_addr.to_socket_addrs() {
485 if let Some(socket_addr) = addrs.next() {
486 backends.insert(Backend {
487 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(socket_addr),
488 weight: 1,
489 ext: http::Extensions::new(),
490 });
491 }
492 }
493 }
494 }
495
496 if backends.is_empty() && !body.starts_with("[]") && !body.is_empty() {
497 warn!(
498 service = %service_name,
499 body_len = body.len(),
500 "Failed to parse Consul response, no backends found"
501 );
502 }
503
504 Ok(backends)
505}
506
507pub struct KubernetesDiscovery {
516 namespace: String,
518 service: String,
520 port_name: Option<String>,
522 refresh_interval: Duration,
524 kubeconfig: Option<String>,
526 cached_backends: RwLock<BTreeSet<Backend>>,
528 last_resolution: RwLock<Instant>,
530}
531
532impl KubernetesDiscovery {
533 pub fn new(
535 namespace: String,
536 service: String,
537 port_name: Option<String>,
538 refresh_interval: Duration,
539 kubeconfig: Option<String>,
540 ) -> Self {
541 Self {
542 namespace,
543 service,
544 port_name,
545 refresh_interval,
546 kubeconfig,
547 cached_backends: RwLock::new(BTreeSet::new()),
548 last_resolution: RwLock::new(Instant::now() - refresh_interval),
549 }
550 }
551
552 fn needs_refresh(&self) -> bool {
554 let last = *self.last_resolution.read();
555 last.elapsed() >= self.refresh_interval
556 }
557
558 fn get_api_config(&self) -> Result<(String, String), Box<Error>> {
560 if self.kubeconfig.is_some() {
561 return Err(Error::explain(
563 ErrorType::InternalError,
564 "Kubeconfig parsing not yet implemented, use in-cluster config",
565 ));
566 }
567
568 let host = std::env::var("KUBERNETES_SERVICE_HOST").map_err(|_| {
570 Error::explain(
571 ErrorType::InternalError,
572 "KUBERNETES_SERVICE_HOST not set, not running in Kubernetes?",
573 )
574 })?;
575 let port = std::env::var("KUBERNETES_SERVICE_PORT").unwrap_or_else(|_| "443".to_string());
576 let token = std::fs::read_to_string("/var/run/secrets/kubernetes.io/serviceaccount/token")
577 .map_err(|e| {
578 Error::explain(
579 ErrorType::InternalError,
580 format!("Failed to read service account token: {}", e),
581 )
582 })?;
583
584 Ok((format!("https://{}:{}", host, port), token))
585 }
586}
587
588#[async_trait]
589impl ServiceDiscovery for KubernetesDiscovery {
590 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
591 if !self.needs_refresh() {
592 let backends = self.cached_backends.read().clone();
593 return Ok((backends, HashMap::new()));
594 }
595
596 trace!(
597 namespace = %self.namespace,
598 service = %self.service,
599 "Querying Kubernetes for endpoint discovery"
600 );
601
602 let (api_server, _token) = match self.get_api_config() {
604 Ok(config) => config,
605 Err(e) => {
606 let cached = self.cached_backends.read().clone();
607 if !cached.is_empty() {
608 warn!(
609 service = %self.service,
610 error = %e,
611 cached_count = cached.len(),
612 "Kubernetes config unavailable, using cached backends"
613 );
614 return Ok((cached, HashMap::new()));
615 }
616 return Err(e);
617 }
618 };
619
620 let url = format!(
622 "{}/api/v1/namespaces/{}/endpoints/{}",
623 api_server, self.namespace, self.service
624 );
625
626 debug!(
627 url = %url,
628 namespace = %self.namespace,
629 service = %self.service,
630 "Kubernetes endpoint URL constructed"
631 );
632
633 warn!(
637 service = %self.service,
638 "Kubernetes discovery requires full HTTP client - returning cached or empty"
639 );
640
641 let cached = self.cached_backends.read().clone();
642 if !cached.is_empty() {
643 return Ok((cached, HashMap::new()));
644 }
645
646 Ok((BTreeSet::new(), HashMap::new()))
648 }
649}
650
651pub struct DiscoveryManager {
656 discoveries: RwLock<HashMap<String, Arc<dyn ServiceDiscovery + Send + Sync>>>,
658}
659
660impl DiscoveryManager {
661 pub fn new() -> Self {
663 Self {
664 discoveries: RwLock::new(HashMap::new()),
665 }
666 }
667
668 pub fn register(&self, upstream_id: &str, config: DiscoveryConfig) -> Result<(), Box<Error>> {
670 let discovery: Arc<dyn ServiceDiscovery + Send + Sync> = match config {
671 DiscoveryConfig::Static { backends } => {
672 let backend_set = backends
673 .iter()
674 .filter_map(|addr| {
675 addr.to_socket_addrs()
676 .ok()
677 .and_then(|mut addrs| addrs.next())
678 .map(|addr| Backend {
679 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
680 weight: 1,
681 ext: http::Extensions::new(),
682 })
683 })
684 .collect();
685
686 info!(
687 upstream_id = %upstream_id,
688 backend_count = backends.len(),
689 "Registered static service discovery"
690 );
691
692 Arc::new(StaticWrapper(StaticDiscovery::new(backend_set)))
693 }
694 DiscoveryConfig::Dns {
695 hostname,
696 port,
697 refresh_interval,
698 } => {
699 info!(
700 upstream_id = %upstream_id,
701 hostname = %hostname,
702 port = port,
703 refresh_interval_secs = refresh_interval.as_secs(),
704 "Registered DNS service discovery"
705 );
706
707 Arc::new(DnsDiscovery::new(hostname, port, refresh_interval))
708 }
709 DiscoveryConfig::DnsSrv {
710 service,
711 refresh_interval,
712 } => {
713 info!(
714 upstream_id = %upstream_id,
715 service = %service,
716 refresh_interval_secs = refresh_interval.as_secs(),
717 "DNS SRV discovery not yet fully implemented, using DNS A record fallback"
718 );
719
720 let hostname = service
723 .split('.')
724 .skip_while(|s| s.starts_with('_'))
725 .collect::<Vec<_>>()
726 .join(".");
727 Arc::new(DnsDiscovery::new(hostname, 80, refresh_interval))
728 }
729 DiscoveryConfig::Consul {
730 address,
731 service,
732 datacenter,
733 only_passing,
734 refresh_interval,
735 tag,
736 } => {
737 info!(
738 upstream_id = %upstream_id,
739 address = %address,
740 service = %service,
741 datacenter = datacenter.as_deref().unwrap_or("default"),
742 only_passing = only_passing,
743 refresh_interval_secs = refresh_interval.as_secs(),
744 "Registered Consul service discovery"
745 );
746
747 Arc::new(ConsulDiscovery::new(
748 address,
749 service,
750 datacenter,
751 only_passing,
752 refresh_interval,
753 tag,
754 ))
755 }
756 DiscoveryConfig::Kubernetes {
757 namespace,
758 service,
759 port_name,
760 refresh_interval,
761 kubeconfig,
762 } => {
763 info!(
764 upstream_id = %upstream_id,
765 namespace = %namespace,
766 service = %service,
767 port_name = port_name.as_deref().unwrap_or("default"),
768 refresh_interval_secs = refresh_interval.as_secs(),
769 "Registered Kubernetes endpoint discovery"
770 );
771
772 Arc::new(KubernetesDiscovery::new(
773 namespace,
774 service,
775 port_name,
776 refresh_interval,
777 kubeconfig,
778 ))
779 }
780 DiscoveryConfig::File {
781 path,
782 watch_interval,
783 } => {
784 info!(
785 upstream_id = %upstream_id,
786 path = %path,
787 watch_interval_secs = watch_interval.as_secs(),
788 "File-based discovery not yet implemented, using empty static"
789 );
790
791 Arc::new(StaticWrapper(StaticDiscovery::new(BTreeSet::new())))
793 }
794 };
795
796 self.discoveries
797 .write()
798 .insert(upstream_id.to_string(), discovery);
799 Ok(())
800 }
801
802 pub fn get(&self, upstream_id: &str) -> Option<Arc<dyn ServiceDiscovery + Send + Sync>> {
804 self.discoveries.read().get(upstream_id).cloned()
805 }
806
807 pub async fn discover(
809 &self,
810 upstream_id: &str,
811 ) -> Option<Result<(BTreeSet<Backend>, HashMap<u64, bool>)>> {
812 let discovery = self.get(upstream_id)?;
813 Some(discovery.discover().await)
814 }
815
816 pub fn remove(&self, upstream_id: &str) {
818 self.discoveries.write().remove(upstream_id);
819 }
820
821 pub fn count(&self) -> usize {
823 self.discoveries.read().len()
824 }
825}
826
827impl Default for DiscoveryManager {
828 fn default() -> Self {
829 Self::new()
830 }
831}
832
833struct StaticWrapper(Box<StaticDiscovery>);
835
836#[async_trait]
837impl ServiceDiscovery for StaticWrapper {
838 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
839 self.0.discover().await
840 }
841}
842
843unsafe impl Send for StaticWrapper {}
845unsafe impl Sync for StaticWrapper {}
846
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the inet socket addresses of `backends`, sorted for stable comparisons.
    fn inet_addrs(backends: &BTreeSet<Backend>) -> Vec<std::net::SocketAddr> {
        let mut addrs: Vec<std::net::SocketAddr> = backends
            .iter()
            .filter_map(|backend| match &backend.addr {
                pingora_core::protocols::l4::socket::SocketAddr::Inet(addr) => Some(*addr),
                // Unix-socket addresses only exist on unix targets; elsewhere this arm is dead.
                #[allow(unreachable_patterns)]
                _ => None,
            })
            .collect();
        addrs.sort();
        addrs
    }

    #[test]
    fn test_discovery_config_default() {
        let config = DiscoveryConfig::default();
        match config {
            DiscoveryConfig::Static { backends } => {
                assert_eq!(backends.len(), 1);
                assert_eq!(backends[0], "127.0.0.1:8080");
            }
            _ => panic!("Expected Static config"),
        }
    }

    #[tokio::test]
    async fn test_discovery_manager() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "test-upstream",
                DiscoveryConfig::Static {
                    backends: vec!["127.0.0.1:8080".to_string(), "127.0.0.1:8081".to_string()],
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);

        let result = manager.discover("test-upstream").await;
        assert!(result.is_some());
        let (backends, _) = result.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }

    #[test]
    fn test_dns_discovery_needs_refresh() {
        let discovery = DnsDiscovery::new(
            "localhost".to_string(),
            8080,
            Duration::from_secs(0),
        );

        assert!(discovery.needs_refresh());
    }

    #[test]
    fn test_consul_discovery_url_building() {
        let discovery = ConsulDiscovery::new(
            "http://localhost:8500".to_string(),
            "my-service".to_string(),
            Some("dc1".to_string()),
            true,
            Duration::from_secs(10),
            Some("production".to_string()),
        );

        let url = discovery.build_url();
        assert!(url.starts_with("http://localhost:8500/v1/health/service/my-service"));
        assert!(url.contains("passing=true"));
        assert!(url.contains("dc=dc1"));
        assert!(url.contains("tag=production"));
    }

    #[test]
    fn test_consul_discovery_url_minimal() {
        let discovery = ConsulDiscovery::new(
            "http://consul.local:8500".to_string(),
            "backend".to_string(),
            None,
            false,
            Duration::from_secs(30),
            None,
        );

        let url = discovery.build_url();
        assert_eq!(url, "http://consul.local:8500/v1/health/service/backend");
    }

    #[test]
    fn test_kubernetes_discovery_config() {
        let discovery = KubernetesDiscovery::new(
            "default".to_string(),
            "my-service".to_string(),
            Some("http".to_string()),
            Duration::from_secs(10),
            None,
        );

        assert!(discovery.needs_refresh());
    }

    #[test]
    fn test_parse_consul_response_empty() {
        let body = "[]";
        let backends = parse_consul_response(body, "test").unwrap();
        assert!(backends.is_empty());
    }

    #[test]
    fn test_parse_consul_response_blank_body() {
        let backends = parse_consul_response("", "test").unwrap();
        assert!(backends.is_empty());
    }

    #[test]
    fn test_parse_consul_response_service_address() {
        let body = concat!(
            r#"[{"Node":{"ID":"a","Node":"n1","Address":"10.0.0.1"},"#,
            r#""Service":{"ID":"web1","Service":"web","Tags":[],"Address":"10.0.0.5","Port":8080},"#,
            r#""Checks":[]}]"#
        );
        let backends = parse_consul_response(body, "web").unwrap();
        assert_eq!(
            inet_addrs(&backends),
            vec!["10.0.0.5:8080".parse::<std::net::SocketAddr>().unwrap()]
        );
    }

    #[test]
    fn test_parse_consul_response_multiple_instances() {
        let body = concat!(
            r#"[{"Node":{"ID":"a","Node":"n1","Address":"10.0.0.1"},"#,
            r#""Service":{"ID":"web1","Service":"web","Address":"10.0.0.5","Port":8080},"#,
            r#""Checks":[]},"#,
            r#"{"Node":{"ID":"b","Node":"n2","Address":"10.0.0.2"},"#,
            r#""Service":{"ID":"web2","Service":"web","Address":"10.0.0.6","Port":8081},"#,
            r#""Checks":[]}]"#
        );
        let backends = parse_consul_response(body, "web").unwrap();
        assert_eq!(
            inet_addrs(&backends),
            vec![
                "10.0.0.5:8080".parse::<std::net::SocketAddr>().unwrap(),
                "10.0.0.6:8081".parse::<std::net::SocketAddr>().unwrap(),
            ]
        );
    }

    #[tokio::test]
    async fn test_discovery_manager_consul() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "consul-upstream",
                DiscoveryConfig::Consul {
                    address: "http://localhost:8500".to_string(),
                    service: "my-service".to_string(),
                    datacenter: Some("dc1".to_string()),
                    only_passing: true,
                    refresh_interval: Duration::from_secs(10),
                    tag: None,
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("consul-upstream").is_some());
    }

    #[tokio::test]
    async fn test_discovery_manager_kubernetes() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "k8s-upstream",
                DiscoveryConfig::Kubernetes {
                    namespace: "production".to_string(),
                    service: "api-server".to_string(),
                    port_name: Some("http".to_string()),
                    refresh_interval: Duration::from_secs(15),
                    kubeconfig: None,
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("k8s-upstream").is_some());
    }
}