1use async_trait::async_trait;
45use parking_lot::RwLock;
46use pingora::prelude::*;
47use pingora_load_balancing::discovery::{ServiceDiscovery, Static as StaticDiscovery};
48use pingora_load_balancing::Backend;
49use std::collections::{BTreeSet, HashMap};
50use std::net::ToSocketAddrs;
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53use tracing::{debug, error, info, trace, warn};
54
/// Describes how the backends of an upstream are discovered.
///
/// Each variant selects one discovery mechanism and carries its settings;
/// `DiscoveryManager::register` turns a value of this type into a discovery
/// implementation.
#[derive(Debug, Clone)]
pub enum DiscoveryConfig {
    /// A fixed list of backend addresses.
    Static {
        /// Addresses in `host:port` form, resolved once at registration.
        backends: Vec<String>,
    },
    /// Backends obtained by resolving a hostname.
    Dns {
        /// Hostname to resolve.
        hostname: String,
        /// Port paired with every resolved address.
        port: u16,
        /// Minimum time between two resolutions.
        refresh_interval: Duration,
    },
    /// Backends named by a DNS SRV style service name.
    ///
    /// NOTE(review): registration currently strips the leading `_`-prefixed
    /// labels and resolves the remaining hostname on port 80 instead of
    /// performing a real SRV lookup.
    DnsSrv {
        /// SRV service name.
        service: String,
        /// Minimum time between two resolutions.
        refresh_interval: Duration,
    },
    /// Backends listed by the Consul health endpoint of a service.
    Consul {
        /// Base URL of the Consul HTTP API, e.g. `http://localhost:8500`.
        address: String,
        /// Name of the Consul service to query.
        service: String,
        /// Sent as the `dc` query parameter when set.
        datacenter: Option<String>,
        /// Adds `passing=true` to the query when `true`.
        only_passing: bool,
        /// Minimum time between two Consul queries.
        refresh_interval: Duration,
        /// Sent as the `tag` query parameter when set.
        tag: Option<String>,
    },
    /// Backends taken from a Kubernetes `Endpoints` object.
    Kubernetes {
        /// Namespace containing the endpoints object.
        namespace: String,
        /// Name of the endpoints object (the service name).
        service: String,
        /// Name of the endpoint port to use; the first port of each subset is
        /// used when `None`.
        port_name: Option<String>,
        /// Minimum time between two API queries.
        refresh_interval: Duration,
        /// Path to a kubeconfig file. When `None`, in-cluster credentials are
        /// tried first and the default kubeconfig location second.
        kubeconfig: Option<String>,
    },
    /// Backends read from a text file.
    ///
    /// The file holds one backend per line as `host:port [weight=N]`; blank
    /// lines and lines starting with `#` are ignored.
    File {
        /// Path of the backends file.
        path: String,
        /// Minimum time between two modification checks of the file.
        watch_interval: Duration,
    },
}
115
116impl Default for DiscoveryConfig {
117 fn default() -> Self {
118 Self::Static {
119 backends: vec!["127.0.0.1:8080".to_string()],
120 }
121 }
122}
123
/// Service discovery that resolves a hostname and caches the resulting
/// backend set between refreshes.
pub struct DnsDiscovery {
    /// Hostname handed to the resolver.
    hostname: String,
    /// Port combined with the hostname for resolution.
    port: u16,
    /// Minimum age of the cache before another resolution is attempted.
    refresh_interval: Duration,
    /// Backends from the most recent successful resolution.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent successful resolution; back-dated at
    /// construction so the first `discover` call resolves.
    last_resolution: RwLock<Instant>,
}
136
137impl DnsDiscovery {
138 pub fn new(hostname: String, port: u16, refresh_interval: Duration) -> Self {
140 Self {
141 hostname,
142 port,
143 refresh_interval,
144 cached_backends: RwLock::new(BTreeSet::new()),
145 last_resolution: RwLock::new(Instant::now() - refresh_interval),
146 }
147 }
148
149 fn resolve(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
151 let address = format!("{}:{}", self.hostname, self.port);
152
153 trace!(
154 hostname = %self.hostname,
155 port = self.port,
156 "Resolving DNS for service discovery"
157 );
158
159 match address.to_socket_addrs() {
160 Ok(addrs) => {
161 let backends: BTreeSet<Backend> = addrs
162 .map(|addr| Backend {
163 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
164 weight: 1,
165 ext: http::Extensions::new(),
166 })
167 .collect();
168
169 debug!(
170 hostname = %self.hostname,
171 backend_count = backends.len(),
172 "DNS resolution successful"
173 );
174
175 Ok(backends)
176 }
177 Err(e) => {
178 error!(
179 hostname = %self.hostname,
180 error = %e,
181 "DNS resolution failed"
182 );
183 Err(Error::explain(
184 ErrorType::ConnectNoRoute,
185 format!("DNS resolution failed for {}: {}", self.hostname, e),
186 ))
187 }
188 }
189 }
190
191 fn needs_refresh(&self) -> bool {
193 let last = *self.last_resolution.read();
194 last.elapsed() >= self.refresh_interval
195 }
196}
197
198#[async_trait]
199impl ServiceDiscovery for DnsDiscovery {
200 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
201 if self.needs_refresh() {
203 match self.resolve() {
204 Ok(backends) => {
205 *self.cached_backends.write() = backends;
206 *self.last_resolution.write() = Instant::now();
207 }
208 Err(e) => {
209 let cached = self.cached_backends.read().clone();
211 if !cached.is_empty() {
212 warn!(
213 hostname = %self.hostname,
214 error = %e,
215 cached_count = cached.len(),
216 "DNS resolution failed, using cached backends"
217 );
218 return Ok((cached, HashMap::new()));
219 }
220 return Err(e);
221 }
222 }
223 }
224
225 let backends = self.cached_backends.read().clone();
226 Ok((backends, HashMap::new()))
227 }
228}
229
/// Service discovery backed by the Consul health endpoint
/// (`/v1/health/service/<service>`), with the result cached between refreshes.
pub struct ConsulDiscovery {
    /// Base URL of the Consul HTTP API.
    address: String,
    /// Name of the Consul service to query.
    service: String,
    /// Sent as the `dc` query parameter when set.
    datacenter: Option<String>,
    /// Adds `passing=true` to the query when `true`.
    only_passing: bool,
    /// Minimum age of the cache before Consul is queried again.
    refresh_interval: Duration,
    /// Sent as the `tag` query parameter when set.
    tag: Option<String>,
    /// Backends from the most recent successful query.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent successful query; back-dated at construction
    /// so the first `discover` call queries Consul.
    last_resolution: RwLock<Instant>,
}
255
256impl ConsulDiscovery {
257 pub fn new(
259 address: String,
260 service: String,
261 datacenter: Option<String>,
262 only_passing: bool,
263 refresh_interval: Duration,
264 tag: Option<String>,
265 ) -> Self {
266 Self {
267 address,
268 service,
269 datacenter,
270 only_passing,
271 refresh_interval,
272 tag,
273 cached_backends: RwLock::new(BTreeSet::new()),
274 last_resolution: RwLock::new(Instant::now() - refresh_interval),
275 }
276 }
277
278 fn build_url(&self) -> String {
280 let mut url = format!(
281 "{}/v1/health/service/{}",
282 self.address.trim_end_matches('/'),
283 self.service
284 );
285
286 let mut params = Vec::new();
287 if self.only_passing {
288 params.push("passing=true".to_string());
289 }
290 if let Some(dc) = &self.datacenter {
291 params.push(format!("dc={}", dc));
292 }
293 if let Some(tag) = &self.tag {
294 params.push(format!("tag={}", tag));
295 }
296
297 if !params.is_empty() {
298 url.push('?');
299 url.push_str(¶ms.join("&"));
300 }
301
302 url
303 }
304
305 fn needs_refresh(&self) -> bool {
307 let last = *self.last_resolution.read();
308 last.elapsed() >= self.refresh_interval
309 }
310}
311
312#[async_trait]
313impl ServiceDiscovery for ConsulDiscovery {
314 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
315 if !self.needs_refresh() {
316 let backends = self.cached_backends.read().clone();
317 return Ok((backends, HashMap::new()));
318 }
319
320 let url = self.build_url();
321 trace!(
322 service = %self.service,
323 url = %url,
324 "Querying Consul for service discovery"
325 );
326
327 let result = tokio::task::spawn_blocking({
330 let url = url.clone();
331 let service = self.service.clone();
332 move || -> Result<BTreeSet<Backend>, Box<Error>> {
333 let url_parsed = url
336 .trim_start_matches("http://")
337 .trim_start_matches("https://");
338 let (host_port, path) = url_parsed.split_once('/').unwrap_or((url_parsed, ""));
339
340 let socket_addr = host_port
341 .to_socket_addrs()
342 .map_err(|e| {
343 Error::explain(
344 ErrorType::ConnectNoRoute,
345 format!("Failed to resolve Consul address: {}", e),
346 )
347 })?
348 .next()
349 .ok_or_else(|| {
350 Error::explain(
351 ErrorType::ConnectNoRoute,
352 "Failed to resolve Consul address",
353 )
354 })?;
355
356 let stream = match std::net::TcpStream::connect_timeout(
357 &socket_addr,
358 Duration::from_secs(5),
359 ) {
360 Ok(s) => s,
361 Err(e) => {
362 return Err(Error::explain(
363 ErrorType::ConnectTimedout,
364 format!("Failed to connect to Consul: {}", e),
365 ));
366 }
367 };
368
369 stream
370 .set_read_timeout(Some(Duration::from_secs(10)))
371 .map_err(|e| {
372 Error::explain(
373 ErrorType::InternalError,
374 format!("Failed to set read timeout: {}", e),
375 )
376 })?;
377 stream
378 .set_write_timeout(Some(Duration::from_secs(5)))
379 .map_err(|e| {
380 Error::explain(
381 ErrorType::InternalError,
382 format!("Failed to set write timeout: {}", e),
383 )
384 })?;
385
386 use std::io::{Read, Write};
387 let request = format!(
388 "GET /{} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
389 path, host_port
390 );
391
392 let mut stream = stream;
393 stream.write_all(request.as_bytes()).map_err(|e| {
394 Error::explain(
395 ErrorType::WriteError,
396 format!("Failed to send request: {}", e),
397 )
398 })?;
399
400 let mut response = String::new();
401 stream.read_to_string(&mut response).map_err(|e| {
402 Error::explain(
403 ErrorType::ReadError,
404 format!("Failed to read response: {}", e),
405 )
406 })?;
407
408 let body = response.split("\r\n\r\n").nth(1).unwrap_or("");
410
411 let backends = parse_consul_response(body, &service)?;
414
415 Ok(backends)
416 }
417 })
418 .await
419 .map_err(|e| Error::explain(ErrorType::InternalError, format!("Task failed: {}", e)))?;
420
421 match result {
422 Ok(backends) => {
423 info!(
424 service = %self.service,
425 backend_count = backends.len(),
426 "Consul discovery successful"
427 );
428 *self.cached_backends.write() = backends.clone();
429 *self.last_resolution.write() = Instant::now();
430 Ok((backends, HashMap::new()))
431 }
432 Err(e) => {
433 let cached = self.cached_backends.read().clone();
434 if !cached.is_empty() {
435 warn!(
436 service = %self.service,
437 error = %e,
438 cached_count = cached.len(),
439 "Consul query failed, using cached backends"
440 );
441 return Ok((cached, HashMap::new()));
442 }
443 Err(e)
444 }
445 }
446 }
447}
448
449fn parse_consul_response(body: &str, service_name: &str) -> Result<BTreeSet<Backend>, Box<Error>> {
451 let mut backends = BTreeSet::new();
454
455 let entries: Vec<&str> = body.split(r#""Service":"#).skip(1).collect();
457
458 for entry in entries {
459 let port = entry
461 .split(r#""Port":"#)
462 .nth(1)
463 .and_then(|s| s.split(|c: char| !c.is_ascii_digit()).next())
464 .and_then(|s| s.parse::<u16>().ok());
465
466 let service_addr = entry
468 .split(r#""Address":""#)
469 .nth(1)
470 .and_then(|s| s.split('"').next())
471 .filter(|s| !s.is_empty());
472
473 let node_addr = body
475 .split(r#""Node":"#)
476 .nth(1)
477 .and_then(|s| s.split(r#""Address":""#).nth(1))
478 .and_then(|s| s.split('"').next());
479
480 let address = service_addr.or(node_addr);
481
482 if let (Some(addr), Some(port)) = (address, port) {
483 let full_addr = format!("{}:{}", addr, port);
484 if let Ok(mut addrs) = full_addr.to_socket_addrs() {
485 if let Some(socket_addr) = addrs.next() {
486 backends.insert(Backend {
487 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(socket_addr),
488 weight: 1,
489 ext: http::Extensions::new(),
490 });
491 }
492 }
493 }
494 }
495
496 if backends.is_empty() && !body.starts_with("[]") && !body.is_empty() {
497 warn!(
498 service = %service_name,
499 body_len = body.len(),
500 "Failed to parse Consul response, no backends found"
501 );
502 }
503
504 Ok(backends)
505}
506
507#[cfg(feature = "kubernetes")]
512use crate::kubeconfig::{KubeAuth, Kubeconfig, ResolvedKubeConfig};
513
/// Service discovery backed by a Kubernetes `Endpoints` object, with the
/// result cached between refreshes.
///
/// Querying the API requires the `kubernetes` feature; without it `discover`
/// only returns the (empty) cache.
pub struct KubernetesDiscovery {
    /// Namespace containing the endpoints object.
    namespace: String,
    /// Name of the endpoints object (the service name).
    service: String,
    /// Name of the endpoint port to use; the first port is used when `None`.
    port_name: Option<String>,
    /// Minimum age of the cache before the API is queried again.
    refresh_interval: Duration,
    /// Optional path to a kubeconfig file.
    kubeconfig: Option<String>,
    /// Backends from the most recent successful query.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent successful query; back-dated at construction
    /// so the first `discover` call refreshes.
    last_resolution: RwLock<Instant>,
    /// Kubeconfig resolved on first use and reused afterwards.
    #[cfg(feature = "kubernetes")]
    resolved_config: RwLock<Option<ResolvedKubeConfig>>,
}
556
557impl KubernetesDiscovery {
558 pub fn new(
560 namespace: String,
561 service: String,
562 port_name: Option<String>,
563 refresh_interval: Duration,
564 kubeconfig: Option<String>,
565 ) -> Self {
566 Self {
567 namespace,
568 service,
569 port_name,
570 refresh_interval,
571 kubeconfig,
572 cached_backends: RwLock::new(BTreeSet::new()),
573 last_resolution: RwLock::new(Instant::now() - refresh_interval),
574 #[cfg(feature = "kubernetes")]
575 resolved_config: RwLock::new(None),
576 }
577 }
578
579 fn needs_refresh(&self) -> bool {
581 let last = *self.last_resolution.read();
582 last.elapsed() >= self.refresh_interval
583 }
584
585 fn get_in_cluster_config(&self) -> Result<(String, String), Box<Error>> {
587 let host = std::env::var("KUBERNETES_SERVICE_HOST").map_err(|_| {
588 Error::explain(
589 ErrorType::InternalError,
590 "KUBERNETES_SERVICE_HOST not set, not running in Kubernetes?",
591 )
592 })?;
593 let port = std::env::var("KUBERNETES_SERVICE_PORT").unwrap_or_else(|_| "443".to_string());
594 let token = std::fs::read_to_string("/var/run/secrets/kubernetes.io/serviceaccount/token")
595 .map_err(|e| {
596 Error::explain(
597 ErrorType::InternalError,
598 format!("Failed to read service account token: {}", e),
599 )
600 })?;
601
602 Ok((format!("https://{}:{}", host, port), token.trim().to_string()))
603 }
604
605 #[cfg(feature = "kubernetes")]
607 fn load_kubeconfig(&self) -> Result<ResolvedKubeConfig, Box<Error>> {
608 if let Some(config) = self.resolved_config.read().as_ref() {
610 return Ok(config.clone());
611 }
612
613 let kubeconfig = if let Some(path) = &self.kubeconfig {
614 Kubeconfig::from_file(path).map_err(|e| {
615 Error::explain(
616 ErrorType::InternalError,
617 format!("Failed to load kubeconfig from {}: {}", path, e),
618 )
619 })?
620 } else {
621 Kubeconfig::from_default_location().map_err(|e| {
622 Error::explain(
623 ErrorType::InternalError,
624 format!("Failed to load kubeconfig from default location: {}", e),
625 )
626 })?
627 };
628
629 let resolved = kubeconfig.resolve_current().map_err(|e| {
630 Error::explain(
631 ErrorType::InternalError,
632 format!("Failed to resolve kubeconfig context: {}", e),
633 )
634 })?;
635
636 *self.resolved_config.write() = Some(resolved.clone());
638
639 Ok(resolved)
640 }
641}
642
/// Minimal deserialization types for the parts of a Kubernetes `Endpoints`
/// object that discovery reads.
#[cfg(feature = "kubernetes")]
mod k8s_types {
    use serde::Deserialize;

    /// Top-level `Endpoints` object.
    #[derive(Debug, Deserialize)]
    pub struct Endpoints {
        /// Address/port groups; may be absent.
        pub subsets: Option<Vec<EndpointSubset>>,
    }

    /// One group of addresses sharing the same ports.
    #[derive(Debug, Deserialize)]
    pub struct EndpointSubset {
        /// Addresses of this subset; may be absent.
        pub addresses: Option<Vec<EndpointAddress>>,
        /// Ports exposed by every address of this subset; may be absent.
        pub ports: Option<Vec<EndpointPort>>,
    }

    /// A single endpoint address.
    #[derive(Debug, Deserialize)]
    pub struct EndpointAddress {
        /// IP of the endpoint.
        pub ip: String,
        /// Optional hostname of the endpoint.
        pub hostname: Option<String>,
    }

    /// A single endpoint port.
    #[derive(Debug, Deserialize)]
    pub struct EndpointPort {
        /// Optional port name, matched against the configured `port_name`.
        pub name: Option<String>,
        /// Port number.
        pub port: u16,
        /// Optional protocol string.
        pub protocol: Option<String>,
    }
}
672
#[cfg(feature = "kubernetes")]
impl KubernetesDiscovery {
    /// Fetches the `Endpoints` object of the configured service from the
    /// Kubernetes API and converts it into a backend set.
    ///
    /// Credentials come from the configured kubeconfig when a path was given;
    /// otherwise in-cluster credentials are tried first and the default
    /// kubeconfig second. For every subset the port named `port_name` is used
    /// (the first port when no name is configured); each address of the subset
    /// becomes one backend with weight 1.
    ///
    /// # Errors
    ///
    /// Fails when no credentials can be loaded, the HTTP client cannot be
    /// built, the request fails, the API answers with a non-success status, or
    /// the response body cannot be deserialized.
    async fn fetch_endpoints(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
        let (api_server, auth, ca_cert, skip_verify) = if self.kubeconfig.is_some() {
            let config = self.load_kubeconfig()?;
            (
                config.server,
                config.auth,
                config.ca_cert,
                config.insecure_skip_tls_verify,
            )
        } else {
            match self.get_in_cluster_config() {
                Ok((server, token)) => {
                    // A missing CA bundle is tolerated; the client then relies
                    // on its default roots.
                    let ca =
                        std::fs::read("/var/run/secrets/kubernetes.io/serviceaccount/ca.crt").ok();
                    (server, KubeAuth::Token(token), ca, false)
                }
                Err(e) => {
                    debug!(
                        error = %e,
                        "In-cluster config not available, trying default kubeconfig"
                    );
                    let config = self.load_kubeconfig()?;
                    (
                        config.server,
                        config.auth,
                        config.ca_cert,
                        config.insecure_skip_tls_verify,
                    )
                }
            }
        };

        let url = format!(
            "{}/api/v1/namespaces/{}/endpoints/{}",
            api_server.trim_end_matches('/'),
            self.namespace,
            self.service
        );

        debug!(
            url = %url,
            namespace = %self.namespace,
            service = %self.service,
            "Fetching Kubernetes endpoints"
        );

        let client_builder = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .danger_accept_invalid_certs(skip_verify);

        let client_builder = if let Some(ca_data) = ca_cert {
            let cert = reqwest::Certificate::from_pem(&ca_data).map_err(|e| {
                Error::explain(
                    ErrorType::InternalError,
                    format!("Failed to parse CA certificate: {}", e),
                )
            })?;
            client_builder.add_root_certificate(cert)
        } else {
            client_builder
        };

        let client_builder = match &auth {
            KubeAuth::ClientCert { cert, key } => {
                // The identity is built from certificate and key concatenated
                // into a single PEM buffer.
                let mut identity_pem = cert.clone();
                identity_pem.extend_from_slice(key);
                let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
                    Error::explain(
                        ErrorType::InternalError,
                        format!("Failed to create client identity: {}", e),
                    )
                })?;
                client_builder.identity(identity)
            }
            _ => client_builder,
        };

        let client = client_builder.build().map_err(|e| {
            Error::explain(
                ErrorType::InternalError,
                format!("Failed to create HTTP client: {}", e),
            )
        })?;

        let mut request = client.get(&url);
        if let KubeAuth::Token(token) = &auth {
            request = request.bearer_auth(token);
        }

        let response = request.send().await.map_err(|e| {
            Error::explain(
                ErrorType::ConnectError,
                format!("Failed to connect to Kubernetes API: {}", e),
            )
        })?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(Error::explain(
                ErrorType::HTTPStatus(status.as_u16()),
                format!("Kubernetes API returned {}: {}", status, body),
            ));
        }

        let endpoints: k8s_types::Endpoints = response.json().await.map_err(|e| {
            Error::explain(
                ErrorType::InternalError,
                format!("Failed to parse Kubernetes endpoints: {}", e),
            )
        })?;

        let mut backends = BTreeSet::new();
        for subset in endpoints.subsets.unwrap_or_default() {
            let target_port = subset
                .ports
                .as_deref()
                .and_then(|ports| match &self.port_name {
                    Some(wanted) => ports
                        .iter()
                        .find(|p| p.name.as_deref() == Some(wanted.as_str())),
                    None => ports.first(),
                })
                .map(|p| p.port);

            if let (Some(addresses), Some(port)) = (subset.addresses, target_port) {
                for addr in addresses {
                    // Endpoint addresses are IP literals; parsing them directly
                    // avoids a resolver call on the async executor.
                    match addr.ip.parse::<std::net::IpAddr>() {
                        Ok(ip) => {
                            backends.insert(Backend {
                                addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(
                                    std::net::SocketAddr::new(ip, port),
                                ),
                                weight: 1,
                                ext: http::Extensions::new(),
                            });
                        }
                        Err(e) => {
                            warn!(
                                service = %self.service,
                                ip = %addr.ip,
                                error = %e,
                                "Skipping Kubernetes endpoint with unparsable IP address"
                            );
                        }
                    }
                }
            }
        }

        Ok(backends)
    }
}

#[cfg(feature = "kubernetes")]
#[async_trait]
impl ServiceDiscovery for KubernetesDiscovery {
    /// Returns the current backend set, querying the Kubernetes API first when
    /// the cache is older than the refresh interval.
    ///
    /// A failed query falls back to the cached backends when there are any,
    /// matching the other discovery implementations in this file; the error is
    /// returned only when the cache is empty. The health map in the result is
    /// always empty.
    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
        if !self.needs_refresh() {
            let backends = self.cached_backends.read().clone();
            return Ok((backends, HashMap::new()));
        }

        trace!(
            namespace = %self.namespace,
            service = %self.service,
            "Querying Kubernetes for endpoint discovery"
        );

        match self.fetch_endpoints().await {
            Ok(backends) => {
                info!(
                    service = %self.service,
                    namespace = %self.namespace,
                    backend_count = backends.len(),
                    "Kubernetes endpoint discovery successful"
                );

                *self.cached_backends.write() = backends.clone();
                *self.last_resolution.write() = Instant::now();

                Ok((backends, HashMap::new()))
            }
            Err(e) => {
                let cached = self.cached_backends.read().clone();
                if !cached.is_empty() {
                    warn!(
                        service = %self.service,
                        namespace = %self.namespace,
                        error = %e,
                        cached_count = cached.len(),
                        "Kubernetes query failed, using cached backends"
                    );
                    return Ok((cached, HashMap::new()));
                }
                Err(e)
            }
        }
    }
}
846
847#[cfg(not(feature = "kubernetes"))]
849#[async_trait]
850impl ServiceDiscovery for KubernetesDiscovery {
851 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
852 if !self.needs_refresh() {
853 let backends = self.cached_backends.read().clone();
854 return Ok((backends, HashMap::new()));
855 }
856
857 if self.kubeconfig.is_none() {
859 if let Ok((_, _)) = self.get_in_cluster_config() {
860 warn!(
861 service = %self.service,
862 "Kubernetes discovery requires 'kubernetes' feature flag for full support"
863 );
864 }
865 } else {
866 warn!(
867 service = %self.service,
868 kubeconfig = ?self.kubeconfig,
869 "Kubeconfig support requires 'kubernetes' feature flag"
870 );
871 }
872
873 let cached = self.cached_backends.read().clone();
874 Ok((cached, HashMap::new()))
875 }
876}
877
/// Service discovery that reads backends from a text file and re-reads the
/// file when its modification time changes.
pub struct FileDiscovery {
    /// Path of the backends file.
    path: String,
    /// Minimum time between two modification checks.
    watch_interval: Duration,
    /// Backends from the most recent successful read.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent modification check; back-dated at construction
    /// so the first `discover` call checks the file.
    last_check: RwLock<Instant>,
    /// Modification time recorded at the most recent read, `None` before the
    /// first read.
    last_modified: RwLock<Option<std::time::SystemTime>>,
}
922
923impl FileDiscovery {
924 pub fn new(path: String, watch_interval: Duration) -> Self {
926 Self {
927 path,
928 watch_interval,
929 cached_backends: RwLock::new(BTreeSet::new()),
930 last_check: RwLock::new(Instant::now() - watch_interval),
931 last_modified: RwLock::new(None),
932 }
933 }
934
935 fn needs_check(&self) -> bool {
937 let last = *self.last_check.read();
938 last.elapsed() >= self.watch_interval
939 }
940
941 fn file_modified(&self) -> bool {
943 let metadata = match std::fs::metadata(&self.path) {
944 Ok(m) => m,
945 Err(_) => return true, };
947
948 let modified = match metadata.modified() {
949 Ok(m) => m,
950 Err(_) => return true,
951 };
952
953 let last_known = *self.last_modified.read();
954 match last_known {
955 Some(last) => modified > last,
956 None => true, }
958 }
959
960 fn read_backends(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
962 trace!(path = %self.path, "Reading backends from file");
963
964 let content = std::fs::read_to_string(&self.path).map_err(|e| {
965 Error::explain(
966 ErrorType::ReadError,
967 format!("Failed to read backends file '{}': {}", self.path, e),
968 )
969 })?;
970
971 if let Ok(metadata) = std::fs::metadata(&self.path) {
973 if let Ok(modified) = metadata.modified() {
974 *self.last_modified.write() = Some(modified);
975 }
976 }
977
978 let mut backends = BTreeSet::new();
979 let mut line_num = 0;
980
981 for line in content.lines() {
982 line_num += 1;
983 let line = line.trim();
984
985 if line.is_empty() || line.starts_with('#') {
987 continue;
988 }
989
990 let (address, weight) = Self::parse_backend_line(line, line_num)?;
992
993 match address.to_socket_addrs() {
995 Ok(mut addrs) => {
996 if let Some(socket_addr) = addrs.next() {
997 backends.insert(Backend {
998 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(socket_addr),
999 weight,
1000 ext: http::Extensions::new(),
1001 });
1002 trace!(
1003 address = %address,
1004 weight = weight,
1005 "Added backend from file"
1006 );
1007 } else {
1008 warn!(
1009 path = %self.path,
1010 line = line_num,
1011 address = %address,
1012 "Address resolved but no socket address found"
1013 );
1014 }
1015 }
1016 Err(e) => {
1017 warn!(
1018 path = %self.path,
1019 line = line_num,
1020 address = %address,
1021 error = %e,
1022 "Failed to resolve backend address, skipping"
1023 );
1024 }
1025 }
1026 }
1027
1028 debug!(
1029 path = %self.path,
1030 backend_count = backends.len(),
1031 "Loaded backends from file"
1032 );
1033
1034 Ok(backends)
1035 }
1036
1037 fn parse_backend_line(line: &str, line_num: usize) -> Result<(String, usize), Box<Error>> {
1041 let parts: Vec<&str> = line.split_whitespace().collect();
1042
1043 if parts.is_empty() {
1044 return Err(Error::explain(
1045 ErrorType::InternalError,
1046 format!("Empty backend line at line {}", line_num),
1047 ));
1048 }
1049
1050 let address = parts[0].to_string();
1051 let mut weight = 1usize;
1052
1053 for part in parts.iter().skip(1) {
1055 if let Some(weight_str) = part.strip_prefix("weight=") {
1056 weight = weight_str.parse().unwrap_or_else(|_| {
1057 warn!(
1058 line = line_num,
1059 weight = weight_str,
1060 "Invalid weight value, using default 1"
1061 );
1062 1
1063 });
1064 }
1065 }
1066
1067 Ok((address, weight))
1068 }
1069}
1070
1071#[async_trait]
1072impl ServiceDiscovery for FileDiscovery {
1073 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
1074 if self.needs_check() {
1076 *self.last_check.write() = Instant::now();
1077
1078 if self.file_modified() {
1080 match self.read_backends() {
1081 Ok(backends) => {
1082 info!(
1083 path = %self.path,
1084 backend_count = backends.len(),
1085 "File-based discovery updated backends"
1086 );
1087 *self.cached_backends.write() = backends;
1088 }
1089 Err(e) => {
1090 let cached = self.cached_backends.read().clone();
1092 if !cached.is_empty() {
1093 warn!(
1094 path = %self.path,
1095 error = %e,
1096 cached_count = cached.len(),
1097 "File read failed, using cached backends"
1098 );
1099 return Ok((cached, HashMap::new()));
1100 }
1101 return Err(e);
1102 }
1103 }
1104 }
1105 }
1106
1107 let backends = self.cached_backends.read().clone();
1108 Ok((backends, HashMap::new()))
1109 }
1110}
1111
/// Registry of service-discovery implementations keyed by upstream id.
pub struct DiscoveryManager {
    /// Registered discoveries, shared via `Arc` so lookups can hand out
    /// handles without holding the lock.
    discoveries: RwLock<HashMap<String, Arc<dyn ServiceDiscovery + Send + Sync>>>,
}
1124
1125impl DiscoveryManager {
1126 pub fn new() -> Self {
1128 Self {
1129 discoveries: RwLock::new(HashMap::new()),
1130 }
1131 }
1132
1133 pub fn register(&self, upstream_id: &str, config: DiscoveryConfig) -> Result<(), Box<Error>> {
1135 let discovery: Arc<dyn ServiceDiscovery + Send + Sync> = match config {
1136 DiscoveryConfig::Static { backends } => {
1137 let backend_set = backends
1138 .iter()
1139 .filter_map(|addr| {
1140 addr.to_socket_addrs()
1141 .ok()
1142 .and_then(|mut addrs| addrs.next())
1143 .map(|addr| Backend {
1144 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
1145 weight: 1,
1146 ext: http::Extensions::new(),
1147 })
1148 })
1149 .collect();
1150
1151 info!(
1152 upstream_id = %upstream_id,
1153 backend_count = backends.len(),
1154 "Registered static service discovery"
1155 );
1156
1157 Arc::new(StaticWrapper(StaticDiscovery::new(backend_set)))
1158 }
1159 DiscoveryConfig::Dns {
1160 hostname,
1161 port,
1162 refresh_interval,
1163 } => {
1164 info!(
1165 upstream_id = %upstream_id,
1166 hostname = %hostname,
1167 port = port,
1168 refresh_interval_secs = refresh_interval.as_secs(),
1169 "Registered DNS service discovery"
1170 );
1171
1172 Arc::new(DnsDiscovery::new(hostname, port, refresh_interval))
1173 }
1174 DiscoveryConfig::DnsSrv {
1175 service,
1176 refresh_interval,
1177 } => {
1178 info!(
1179 upstream_id = %upstream_id,
1180 service = %service,
1181 refresh_interval_secs = refresh_interval.as_secs(),
1182 "DNS SRV discovery not yet fully implemented, using DNS A record fallback"
1183 );
1184
1185 let hostname = service
1188 .split('.')
1189 .skip_while(|s| s.starts_with('_'))
1190 .collect::<Vec<_>>()
1191 .join(".");
1192 Arc::new(DnsDiscovery::new(hostname, 80, refresh_interval))
1193 }
1194 DiscoveryConfig::Consul {
1195 address,
1196 service,
1197 datacenter,
1198 only_passing,
1199 refresh_interval,
1200 tag,
1201 } => {
1202 info!(
1203 upstream_id = %upstream_id,
1204 address = %address,
1205 service = %service,
1206 datacenter = datacenter.as_deref().unwrap_or("default"),
1207 only_passing = only_passing,
1208 refresh_interval_secs = refresh_interval.as_secs(),
1209 "Registered Consul service discovery"
1210 );
1211
1212 Arc::new(ConsulDiscovery::new(
1213 address,
1214 service,
1215 datacenter,
1216 only_passing,
1217 refresh_interval,
1218 tag,
1219 ))
1220 }
1221 DiscoveryConfig::Kubernetes {
1222 namespace,
1223 service,
1224 port_name,
1225 refresh_interval,
1226 kubeconfig,
1227 } => {
1228 info!(
1229 upstream_id = %upstream_id,
1230 namespace = %namespace,
1231 service = %service,
1232 port_name = port_name.as_deref().unwrap_or("default"),
1233 refresh_interval_secs = refresh_interval.as_secs(),
1234 "Registered Kubernetes endpoint discovery"
1235 );
1236
1237 Arc::new(KubernetesDiscovery::new(
1238 namespace,
1239 service,
1240 port_name,
1241 refresh_interval,
1242 kubeconfig,
1243 ))
1244 }
1245 DiscoveryConfig::File {
1246 path,
1247 watch_interval,
1248 } => {
1249 info!(
1250 upstream_id = %upstream_id,
1251 path = %path,
1252 watch_interval_secs = watch_interval.as_secs(),
1253 "Registered file-based service discovery"
1254 );
1255
1256 Arc::new(FileDiscovery::new(path, watch_interval))
1257 }
1258 };
1259
1260 self.discoveries
1261 .write()
1262 .insert(upstream_id.to_string(), discovery);
1263 Ok(())
1264 }
1265
1266 pub fn get(&self, upstream_id: &str) -> Option<Arc<dyn ServiceDiscovery + Send + Sync>> {
1268 self.discoveries.read().get(upstream_id).cloned()
1269 }
1270
1271 pub async fn discover(
1273 &self,
1274 upstream_id: &str,
1275 ) -> Option<Result<(BTreeSet<Backend>, HashMap<u64, bool>)>> {
1276 let discovery = self.get(upstream_id)?;
1277 Some(discovery.discover().await)
1278 }
1279
1280 pub fn remove(&self, upstream_id: &str) {
1282 self.discoveries.write().remove(upstream_id);
1283 }
1284
1285 pub fn count(&self) -> usize {
1287 self.discoveries.read().len()
1288 }
1289}
1290
1291impl Default for DiscoveryManager {
1292 fn default() -> Self {
1293 Self::new()
1294 }
1295}
1296
/// Newtype around the boxed static discovery so it can be stored in the
/// manager as an `Arc<dyn ServiceDiscovery + Send + Sync>`.
struct StaticWrapper(Box<StaticDiscovery>);
1299
1300#[async_trait]
1301impl ServiceDiscovery for StaticWrapper {
1302 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
1303 self.0.discover().await
1304 }
1305}
1306
// SAFETY (NOTE(review), unverified): these impls assert by hand that
// `StaticWrapper` may be moved to and shared between threads. Its only field
// is the boxed static discovery; presumably that type holds no thread-affine
// state. Confirm against its definition, and remove these impls if the
// compiler already derives `Send`/`Sync` for the wrapper.
unsafe impl Send for StaticWrapper {}
unsafe impl Sync for StaticWrapper {}
1310
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Creates (or truncates) `path` and writes `lines` to it, one per line.
    fn write_lines(path: &Path, lines: &[&str]) {
        use std::io::Write;

        let mut file = std::fs::File::create(path).unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
    }

    #[test]
    fn test_discovery_config_default() {
        let config = DiscoveryConfig::default();
        match config {
            DiscoveryConfig::Static { backends } => {
                assert_eq!(backends.len(), 1);
                assert_eq!(backends[0], "127.0.0.1:8080");
            }
            _ => panic!("Expected Static config"),
        }
    }

    #[tokio::test]
    async fn test_discovery_manager() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "test-upstream",
                DiscoveryConfig::Static {
                    backends: vec!["127.0.0.1:8080".to_string(), "127.0.0.1:8081".to_string()],
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);

        let result = manager.discover("test-upstream").await;
        assert!(result.is_some());
        let (backends, _) = result.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }

    #[test]
    fn test_dns_discovery_needs_refresh() {
        // A zero interval means every call is due for a refresh.
        let discovery = DnsDiscovery::new("localhost".to_string(), 8080, Duration::from_secs(0));

        assert!(discovery.needs_refresh());
    }

    #[test]
    fn test_consul_discovery_url_building() {
        let discovery = ConsulDiscovery::new(
            "http://localhost:8500".to_string(),
            "my-service".to_string(),
            Some("dc1".to_string()),
            true,
            Duration::from_secs(10),
            Some("production".to_string()),
        );

        let url = discovery.build_url();
        assert!(url.starts_with("http://localhost:8500/v1/health/service/my-service"));
        assert!(url.contains("passing=true"));
        assert!(url.contains("dc=dc1"));
        assert!(url.contains("tag=production"));
    }

    #[test]
    fn test_consul_discovery_url_minimal() {
        let discovery = ConsulDiscovery::new(
            "http://consul.local:8500".to_string(),
            "backend".to_string(),
            None,
            false,
            Duration::from_secs(30),
            None,
        );

        let url = discovery.build_url();
        assert_eq!(url, "http://consul.local:8500/v1/health/service/backend");
    }

    #[test]
    fn test_kubernetes_discovery_config() {
        let discovery = KubernetesDiscovery::new(
            "default".to_string(),
            "my-service".to_string(),
            Some("http".to_string()),
            Duration::from_secs(10),
            None,
        );

        // A freshly constructed discovery is due for its first refresh.
        assert!(discovery.needs_refresh());
    }

    #[test]
    fn test_parse_consul_response_empty() {
        let body = "[]";
        let backends = parse_consul_response(body, "test").unwrap();
        assert!(backends.is_empty());
    }

    #[tokio::test]
    async fn test_discovery_manager_consul() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "consul-upstream",
                DiscoveryConfig::Consul {
                    address: "http://localhost:8500".to_string(),
                    service: "my-service".to_string(),
                    datacenter: Some("dc1".to_string()),
                    only_passing: true,
                    refresh_interval: Duration::from_secs(10),
                    tag: None,
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("consul-upstream").is_some());
    }

    #[tokio::test]
    async fn test_discovery_manager_kubernetes() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "k8s-upstream",
                DiscoveryConfig::Kubernetes {
                    namespace: "production".to_string(),
                    service: "api-server".to_string(),
                    port_name: Some("http".to_string()),
                    refresh_interval: Duration::from_secs(15),
                    kubeconfig: None,
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("k8s-upstream").is_some());
    }

    #[test]
    fn test_file_discovery_parse_backend_line_simple() {
        let (address, weight) = FileDiscovery::parse_backend_line("127.0.0.1:8080", 1).unwrap();
        assert_eq!(address, "127.0.0.1:8080");
        assert_eq!(weight, 1);
    }

    #[test]
    fn test_file_discovery_parse_backend_line_with_weight() {
        let (address, weight) =
            FileDiscovery::parse_backend_line("10.0.0.1:8080 weight=5", 1).unwrap();
        assert_eq!(address, "10.0.0.1:8080");
        assert_eq!(weight, 5);
    }

    #[test]
    fn test_file_discovery_parse_backend_line_hostname() {
        let (address, weight) =
            FileDiscovery::parse_backend_line("backend.example.com:443 weight=2", 1).unwrap();
        assert_eq!(address, "backend.example.com:443");
        assert_eq!(weight, 2);
    }

    #[test]
    fn test_file_discovery_needs_check() {
        // A zero interval means every call is due for a check.
        let discovery =
            FileDiscovery::new("/nonexistent/path.txt".to_string(), Duration::from_secs(0));

        assert!(discovery.needs_check());
    }

    #[tokio::test]
    async fn test_file_discovery_with_temp_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        // Covers a comment line, a blank line and explicit weights.
        write_lines(
            &file_path,
            &[
                "# Backend servers",
                "127.0.0.1:8080",
                "127.0.0.1:8081 weight=2",
                "",
                "127.0.0.1:8082 weight=3",
            ],
        );

        let discovery = FileDiscovery::new(
            file_path.to_string_lossy().to_string(),
            Duration::from_secs(1),
        );

        let (backends, _) = discovery.discover().await.unwrap();

        assert_eq!(backends.len(), 3);

        let weights: Vec<usize> = backends.iter().map(|b| b.weight).collect();
        assert!(weights.contains(&1));
        assert!(weights.contains(&2));
        assert!(weights.contains(&3));
    }

    #[tokio::test]
    async fn test_file_discovery_missing_file_uses_cache() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        write_lines(&file_path, &["127.0.0.1:8080"]);

        let discovery = FileDiscovery::new(
            file_path.to_string_lossy().to_string(),
            Duration::from_secs(0),
        );

        // First call populates the cache.
        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 1);

        std::fs::remove_file(&file_path).unwrap();

        // Non-blocking sleep: this runs on the async test runtime.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // The file is gone, so the cached backends must be served.
        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 1);
    }

    #[tokio::test]
    async fn test_file_discovery_hot_reload() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        write_lines(&file_path, &["127.0.0.1:8080"]);

        let discovery = FileDiscovery::new(
            file_path.to_string_lossy().to_string(),
            Duration::from_millis(10),
        );

        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 1);

        // Let the watch interval elapse and the file timestamp advance.
        tokio::time::sleep(Duration::from_millis(50)).await;

        write_lines(
            &file_path,
            &["127.0.0.1:8080", "127.0.0.1:8081", "127.0.0.1:8082"],
        );

        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 3);
    }

    #[tokio::test]
    async fn test_discovery_manager_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        write_lines(&file_path, &["127.0.0.1:8080", "127.0.0.1:8081"]);

        let manager = DiscoveryManager::new();

        manager
            .register(
                "file-upstream",
                DiscoveryConfig::File {
                    path: file_path.to_string_lossy().to_string(),
                    watch_interval: Duration::from_secs(5),
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("file-upstream").is_some());

        let result = manager.discover("file-upstream").await;
        assert!(result.is_some());
        let (backends, _) = result.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }
}