sentinel_proxy/
discovery.rs

1//! Service Discovery Module
2//!
3//! This module provides service discovery integration using pingora-load-balancing's
4//! ServiceDiscovery trait. Supports:
5//!
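//! - Static: Fixed list of backends (default)
//! - DNS: Resolve backends from DNS A/AAAA records
//! - DNS SRV: Resolve backends from DNS SRV records (not yet implemented;
//!   currently falls back to an A/AAAA lookup of the domain on port 80)
//! - Consul: Discover backends from Consul service catalog
//! - Kubernetes: Discover backends from Kubernetes endpoints (in-cluster
//!   configuration only; the endpoint query itself is still a stub)
//! - File: Watch configuration file for backend changes (not yet implemented;
//!   currently registers an empty backend set)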
12//!
13//! # Example KDL Configuration
14//!
15//! ```kdl
16//! upstream "api" {
17//!     discovery "dns" {
18//!         hostname "api.example.com"
19//!         port 8080
20//!         refresh-interval 30
21//!     }
22//! }
23//!
24//! upstream "backend" {
25//!     discovery "consul" {
26//!         address "http://localhost:8500"
27//!         service "backend-api"
28//!         datacenter "dc1"
29//!         refresh-interval 10
30//!         only-passing true
31//!     }
32//! }
33//!
34//! upstream "k8s-service" {
35//!     discovery "kubernetes" {
36//!         namespace "default"
37//!         service "my-service"
38//!         port-name "http"
39//!         refresh-interval 10
40//!     }
41//! }
42//! ```
43
44use async_trait::async_trait;
45use parking_lot::RwLock;
46use pingora::prelude::*;
47use pingora_load_balancing::discovery::{ServiceDiscovery, Static as StaticDiscovery};
48use pingora_load_balancing::Backend;
49use std::collections::{BTreeSet, HashMap};
50use std::net::ToSocketAddrs;
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53use tracing::{debug, error, info, trace, warn};
54
/// Service discovery configuration
///
/// Each variant selects one discovery mechanism for an upstream and carries
/// the settings that mechanism needs. The type is comparable (`PartialEq`/`Eq`)
/// so callers can detect whether a reloaded configuration actually changed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiscoveryConfig {
    /// Static list of backends
    Static {
        /// Backend addresses (host:port)
        backends: Vec<String>,
    },
    /// DNS-based discovery (A/AAAA records)
    Dns {
        /// DNS hostname to resolve
        hostname: String,
        /// Port for discovered backends
        port: u16,
        /// Resolution interval
        refresh_interval: Duration,
    },
    /// DNS SRV-based discovery
    ///
    /// Note: `DiscoveryManager::register` does not perform SRV lookups yet; it
    /// falls back to an A/AAAA lookup of the domain part of `service`.
    DnsSrv {
        /// Service name for SRV lookup (e.g., "_http._tcp.example.com")
        service: String,
        /// Resolution interval
        refresh_interval: Duration,
    },
    /// Consul service discovery
    Consul {
        /// Consul HTTP API address (e.g., "http://localhost:8500")
        address: String,
        /// Service name in Consul
        service: String,
        /// Datacenter (optional)
        datacenter: Option<String>,
        /// Only return healthy/passing services
        only_passing: bool,
        /// Refresh interval
        refresh_interval: Duration,
        /// Optional tag filter
        tag: Option<String>,
    },
    /// Kubernetes endpoint discovery
    Kubernetes {
        /// Kubernetes namespace
        namespace: String,
        /// Service name
        service: String,
        /// Port name to use (if service has multiple ports)
        port_name: Option<String>,
        /// Refresh interval
        refresh_interval: Duration,
        /// Path to kubeconfig file (None = in-cluster config)
        kubeconfig: Option<String>,
    },
    /// File-based discovery (watches config file)
    ///
    /// Note: `DiscoveryManager::register` does not implement this yet and
    /// registers an empty backend set instead.
    File {
        /// Path to the file containing backend addresses
        path: String,
        /// Watch interval
        watch_interval: Duration,
    },
}
115
116impl Default for DiscoveryConfig {
117    fn default() -> Self {
118        Self::Static {
119            backends: vec!["127.0.0.1:8080".to_string()],
120        }
121    }
122}
123
124/// DNS-based service discovery
125///
126/// Resolves backends from DNS A/AAAA records.
127pub struct DnsDiscovery {
128    hostname: String,
129    port: u16,
130    refresh_interval: Duration,
131    /// Cached backends
132    cached_backends: RwLock<BTreeSet<Backend>>,
133    /// Last resolution time
134    last_resolution: RwLock<Instant>,
135}
136
137impl DnsDiscovery {
138    /// Create a new DNS discovery instance
139    pub fn new(hostname: String, port: u16, refresh_interval: Duration) -> Self {
140        Self {
141            hostname,
142            port,
143            refresh_interval,
144            cached_backends: RwLock::new(BTreeSet::new()),
145            last_resolution: RwLock::new(Instant::now() - refresh_interval),
146        }
147    }
148
149    /// Resolve the hostname to backends
150    fn resolve(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
151        let address = format!("{}:{}", self.hostname, self.port);
152
153        trace!(
154            hostname = %self.hostname,
155            port = self.port,
156            "Resolving DNS for service discovery"
157        );
158
159        match address.to_socket_addrs() {
160            Ok(addrs) => {
161                let backends: BTreeSet<Backend> = addrs
162                    .map(|addr| Backend {
163                        addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
164                        weight: 1,
165                        ext: http::Extensions::new(),
166                    })
167                    .collect();
168
169                debug!(
170                    hostname = %self.hostname,
171                    backend_count = backends.len(),
172                    "DNS resolution successful"
173                );
174
175                Ok(backends)
176            }
177            Err(e) => {
178                error!(
179                    hostname = %self.hostname,
180                    error = %e,
181                    "DNS resolution failed"
182                );
183                Err(Error::explain(
184                    ErrorType::ConnectNoRoute,
185                    format!("DNS resolution failed for {}: {}", self.hostname, e),
186                ))
187            }
188        }
189    }
190
191    /// Check if cache needs refresh
192    fn needs_refresh(&self) -> bool {
193        let last = *self.last_resolution.read();
194        last.elapsed() >= self.refresh_interval
195    }
196}
197
198#[async_trait]
199impl ServiceDiscovery for DnsDiscovery {
200    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
201        // Check if we need to refresh
202        if self.needs_refresh() {
203            match self.resolve() {
204                Ok(backends) => {
205                    *self.cached_backends.write() = backends;
206                    *self.last_resolution.write() = Instant::now();
207                }
208                Err(e) => {
209                    // Return cached backends on error if available
210                    let cached = self.cached_backends.read().clone();
211                    if !cached.is_empty() {
212                        warn!(
213                            hostname = %self.hostname,
214                            error = %e,
215                            cached_count = cached.len(),
216                            "DNS resolution failed, using cached backends"
217                        );
218                        return Ok((cached, HashMap::new()));
219                    }
220                    return Err(e);
221                }
222            }
223        }
224
225        let backends = self.cached_backends.read().clone();
226        Ok((backends, HashMap::new()))
227    }
228}
229
230// ============================================================================
231// Consul Service Discovery
232// ============================================================================
233
234/// Consul-based service discovery
235///
236/// Discovers backends from Consul's service catalog via HTTP API.
237pub struct ConsulDiscovery {
238    /// Consul HTTP API address
239    address: String,
240    /// Service name in Consul
241    service: String,
242    /// Datacenter (optional)
243    datacenter: Option<String>,
244    /// Only return healthy/passing services
245    only_passing: bool,
246    /// Refresh interval
247    refresh_interval: Duration,
248    /// Optional tag filter
249    tag: Option<String>,
250    /// Cached backends
251    cached_backends: RwLock<BTreeSet<Backend>>,
252    /// Last resolution time
253    last_resolution: RwLock<Instant>,
254}
255
256impl ConsulDiscovery {
257    /// Create a new Consul discovery instance
258    pub fn new(
259        address: String,
260        service: String,
261        datacenter: Option<String>,
262        only_passing: bool,
263        refresh_interval: Duration,
264        tag: Option<String>,
265    ) -> Self {
266        Self {
267            address,
268            service,
269            datacenter,
270            only_passing,
271            refresh_interval,
272            tag,
273            cached_backends: RwLock::new(BTreeSet::new()),
274            last_resolution: RwLock::new(Instant::now() - refresh_interval),
275        }
276    }
277
278    /// Build the Consul API URL for service health query
279    fn build_url(&self) -> String {
280        let mut url = format!(
281            "{}/v1/health/service/{}",
282            self.address.trim_end_matches('/'),
283            self.service
284        );
285
286        let mut params = Vec::new();
287        if self.only_passing {
288            params.push("passing=true".to_string());
289        }
290        if let Some(dc) = &self.datacenter {
291            params.push(format!("dc={}", dc));
292        }
293        if let Some(tag) = &self.tag {
294            params.push(format!("tag={}", tag));
295        }
296
297        if !params.is_empty() {
298            url.push('?');
299            url.push_str(&params.join("&"));
300        }
301
302        url
303    }
304
305    /// Check if cache needs refresh
306    fn needs_refresh(&self) -> bool {
307        let last = *self.last_resolution.read();
308        last.elapsed() >= self.refresh_interval
309    }
310}
311
312#[async_trait]
313impl ServiceDiscovery for ConsulDiscovery {
314    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
315        if !self.needs_refresh() {
316            let backends = self.cached_backends.read().clone();
317            return Ok((backends, HashMap::new()));
318        }
319
320        let url = self.build_url();
321        trace!(
322            service = %self.service,
323            url = %url,
324            "Querying Consul for service discovery"
325        );
326
327        // Use a simple HTTP request via std (blocking, but called from async context)
328        // In production, this should use an async HTTP client
329        let result = tokio::task::spawn_blocking({
330            let url = url.clone();
331            let service = self.service.clone();
332            move || -> Result<BTreeSet<Backend>, Box<Error>> {
333                // Simple HTTP GET using std::net
334                // Parse URL to get host and path
335                let url_parsed = url
336                    .trim_start_matches("http://")
337                    .trim_start_matches("https://");
338                let (host_port, path) = url_parsed.split_once('/').unwrap_or((url_parsed, ""));
339
340                let socket_addr = host_port
341                    .to_socket_addrs()
342                    .map_err(|e| {
343                        Error::explain(
344                            ErrorType::ConnectNoRoute,
345                            format!("Failed to resolve Consul address: {}", e),
346                        )
347                    })?
348                    .next()
349                    .ok_or_else(|| {
350                        Error::explain(
351                            ErrorType::ConnectNoRoute,
352                            "Failed to resolve Consul address",
353                        )
354                    })?;
355
356                let stream = match std::net::TcpStream::connect_timeout(
357                    &socket_addr,
358                    Duration::from_secs(5),
359                ) {
360                    Ok(s) => s,
361                    Err(e) => {
362                        return Err(Error::explain(
363                            ErrorType::ConnectTimedout,
364                            format!("Failed to connect to Consul: {}", e),
365                        ));
366                    }
367                };
368
369                stream
370                    .set_read_timeout(Some(Duration::from_secs(10)))
371                    .map_err(|e| {
372                        Error::explain(
373                            ErrorType::InternalError,
374                            format!("Failed to set read timeout: {}", e),
375                        )
376                    })?;
377                stream
378                    .set_write_timeout(Some(Duration::from_secs(5)))
379                    .map_err(|e| {
380                        Error::explain(
381                            ErrorType::InternalError,
382                            format!("Failed to set write timeout: {}", e),
383                        )
384                    })?;
385
386                use std::io::{Read, Write};
387                let request = format!(
388                    "GET /{} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
389                    path, host_port
390                );
391
392                let mut stream = stream;
393                stream.write_all(request.as_bytes()).map_err(|e| {
394                    Error::explain(
395                        ErrorType::WriteError,
396                        format!("Failed to send request: {}", e),
397                    )
398                })?;
399
400                let mut response = String::new();
401                stream.read_to_string(&mut response).map_err(|e| {
402                    Error::explain(
403                        ErrorType::ReadError,
404                        format!("Failed to read response: {}", e),
405                    )
406                })?;
407
408                // Parse response - find JSON body after headers
409                let body = response.split("\r\n\r\n").nth(1).unwrap_or("");
410
411                // Parse Consul response JSON
412                // Format: [{"Node":{"Address":"..."},"Service":{"Port":...}}]
413                let backends = parse_consul_response(body, &service)?;
414
415                Ok(backends)
416            }
417        })
418        .await
419        .map_err(|e| Error::explain(ErrorType::InternalError, format!("Task failed: {}", e)))?;
420
421        match result {
422            Ok(backends) => {
423                info!(
424                    service = %self.service,
425                    backend_count = backends.len(),
426                    "Consul discovery successful"
427                );
428                *self.cached_backends.write() = backends.clone();
429                *self.last_resolution.write() = Instant::now();
430                Ok((backends, HashMap::new()))
431            }
432            Err(e) => {
433                let cached = self.cached_backends.read().clone();
434                if !cached.is_empty() {
435                    warn!(
436                        service = %self.service,
437                        error = %e,
438                        cached_count = cached.len(),
439                        "Consul query failed, using cached backends"
440                    );
441                    return Ok((cached, HashMap::new()));
442                }
443                Err(e)
444            }
445        }
446    }
447}
448
449/// Parse Consul health API response
450fn parse_consul_response(body: &str, service_name: &str) -> Result<BTreeSet<Backend>, Box<Error>> {
451    // Simple JSON parsing without serde dependency
452    // Response format: [{"Node":{"Address":"ip"},"Service":{"Address":"","Port":8080}}]
453    let mut backends = BTreeSet::new();
454
455    // Very basic JSON extraction - in production use serde_json
456    let entries: Vec<&str> = body.split(r#""Service":"#).skip(1).collect();
457
458    for entry in entries {
459        // Extract port
460        let port = entry
461            .split(r#""Port":"#)
462            .nth(1)
463            .and_then(|s| s.split(|c: char| !c.is_ascii_digit()).next())
464            .and_then(|s| s.parse::<u16>().ok());
465
466        // Extract service address (may be empty, fall back to node address)
467        let service_addr = entry
468            .split(r#""Address":""#)
469            .nth(1)
470            .and_then(|s| s.split('"').next())
471            .filter(|s| !s.is_empty());
472
473        // Try to extract node address if service address is empty
474        let node_addr = body
475            .split(r#""Node":"#)
476            .nth(1)
477            .and_then(|s| s.split(r#""Address":""#).nth(1))
478            .and_then(|s| s.split('"').next());
479
480        let address = service_addr.or(node_addr);
481
482        if let (Some(addr), Some(port)) = (address, port) {
483            let full_addr = format!("{}:{}", addr, port);
484            if let Ok(mut addrs) = full_addr.to_socket_addrs() {
485                if let Some(socket_addr) = addrs.next() {
486                    backends.insert(Backend {
487                        addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(socket_addr),
488                        weight: 1,
489                        ext: http::Extensions::new(),
490                    });
491                }
492            }
493        }
494    }
495
496    if backends.is_empty() && !body.starts_with("[]") && !body.is_empty() {
497        warn!(
498            service = %service_name,
499            body_len = body.len(),
500            "Failed to parse Consul response, no backends found"
501        );
502    }
503
504    Ok(backends)
505}
506
507// ============================================================================
508// Kubernetes Endpoint Discovery
509// ============================================================================
510
511/// Kubernetes endpoint discovery
512///
513/// Discovers backends from Kubernetes Endpoints resource.
514/// Requires either in-cluster configuration or kubeconfig file.
515pub struct KubernetesDiscovery {
516    /// Kubernetes namespace
517    namespace: String,
518    /// Service name
519    service: String,
520    /// Port name to use
521    port_name: Option<String>,
522    /// Refresh interval
523    refresh_interval: Duration,
524    /// Kubeconfig path (None = in-cluster)
525    kubeconfig: Option<String>,
526    /// Cached backends
527    cached_backends: RwLock<BTreeSet<Backend>>,
528    /// Last resolution time
529    last_resolution: RwLock<Instant>,
530}
531
532impl KubernetesDiscovery {
533    /// Create a new Kubernetes discovery instance
534    pub fn new(
535        namespace: String,
536        service: String,
537        port_name: Option<String>,
538        refresh_interval: Duration,
539        kubeconfig: Option<String>,
540    ) -> Self {
541        Self {
542            namespace,
543            service,
544            port_name,
545            refresh_interval,
546            kubeconfig,
547            cached_backends: RwLock::new(BTreeSet::new()),
548            last_resolution: RwLock::new(Instant::now() - refresh_interval),
549        }
550    }
551
552    /// Check if cache needs refresh
553    fn needs_refresh(&self) -> bool {
554        let last = *self.last_resolution.read();
555        last.elapsed() >= self.refresh_interval
556    }
557
558    /// Get the Kubernetes API server address and token
559    fn get_api_config(&self) -> Result<(String, String), Box<Error>> {
560        if self.kubeconfig.is_some() {
561            // TODO: Parse kubeconfig file
562            return Err(Error::explain(
563                ErrorType::InternalError,
564                "Kubeconfig parsing not yet implemented, use in-cluster config",
565            ));
566        }
567
568        // In-cluster configuration
569        let host = std::env::var("KUBERNETES_SERVICE_HOST").map_err(|_| {
570            Error::explain(
571                ErrorType::InternalError,
572                "KUBERNETES_SERVICE_HOST not set, not running in Kubernetes?",
573            )
574        })?;
575        let port = std::env::var("KUBERNETES_SERVICE_PORT").unwrap_or_else(|_| "443".to_string());
576        let token = std::fs::read_to_string("/var/run/secrets/kubernetes.io/serviceaccount/token")
577            .map_err(|e| {
578                Error::explain(
579                    ErrorType::InternalError,
580                    format!("Failed to read service account token: {}", e),
581                )
582            })?;
583
584        Ok((format!("https://{}:{}", host, port), token))
585    }
586}
587
588#[async_trait]
589impl ServiceDiscovery for KubernetesDiscovery {
590    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
591        if !self.needs_refresh() {
592            let backends = self.cached_backends.read().clone();
593            return Ok((backends, HashMap::new()));
594        }
595
596        trace!(
597            namespace = %self.namespace,
598            service = %self.service,
599            "Querying Kubernetes for endpoint discovery"
600        );
601
602        // Get API configuration
603        let (api_server, _token) = match self.get_api_config() {
604            Ok(config) => config,
605            Err(e) => {
606                let cached = self.cached_backends.read().clone();
607                if !cached.is_empty() {
608                    warn!(
609                        service = %self.service,
610                        error = %e,
611                        cached_count = cached.len(),
612                        "Kubernetes config unavailable, using cached backends"
613                    );
614                    return Ok((cached, HashMap::new()));
615                }
616                return Err(e);
617            }
618        };
619
620        // Build endpoint URL
621        let url = format!(
622            "{}/api/v1/namespaces/{}/endpoints/{}",
623            api_server, self.namespace, self.service
624        );
625
626        debug!(
627            url = %url,
628            namespace = %self.namespace,
629            service = %self.service,
630            "Kubernetes endpoint URL constructed"
631        );
632
633        // Note: In production, this should make an actual HTTPS request to the K8s API
634        // with proper TLS verification and the bearer token.
635        // For now, we return empty and log that full implementation is needed.
636        warn!(
637            service = %self.service,
638            "Kubernetes discovery requires full HTTP client - returning cached or empty"
639        );
640
641        let cached = self.cached_backends.read().clone();
642        if !cached.is_empty() {
643            return Ok((cached, HashMap::new()));
644        }
645
646        // Return empty set - Kubernetes discovery requires async HTTP client
647        Ok((BTreeSet::new(), HashMap::new()))
648    }
649}
650
651/// Service discovery manager
652///
653/// Manages service discovery for upstreams with support for multiple
654/// discovery mechanisms.
655pub struct DiscoveryManager {
656    /// Discovery implementations keyed by upstream ID
657    discoveries: RwLock<HashMap<String, Arc<dyn ServiceDiscovery + Send + Sync>>>,
658}
659
660impl DiscoveryManager {
661    /// Create a new discovery manager
662    pub fn new() -> Self {
663        Self {
664            discoveries: RwLock::new(HashMap::new()),
665        }
666    }
667
668    /// Register a service discovery for an upstream
669    pub fn register(&self, upstream_id: &str, config: DiscoveryConfig) -> Result<(), Box<Error>> {
670        let discovery: Arc<dyn ServiceDiscovery + Send + Sync> = match config {
671            DiscoveryConfig::Static { backends } => {
672                let backend_set = backends
673                    .iter()
674                    .filter_map(|addr| {
675                        addr.to_socket_addrs()
676                            .ok()
677                            .and_then(|mut addrs| addrs.next())
678                            .map(|addr| Backend {
679                                addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
680                                weight: 1,
681                                ext: http::Extensions::new(),
682                            })
683                    })
684                    .collect();
685
686                info!(
687                    upstream_id = %upstream_id,
688                    backend_count = backends.len(),
689                    "Registered static service discovery"
690                );
691
692                Arc::new(StaticWrapper(StaticDiscovery::new(backend_set)))
693            }
694            DiscoveryConfig::Dns {
695                hostname,
696                port,
697                refresh_interval,
698            } => {
699                info!(
700                    upstream_id = %upstream_id,
701                    hostname = %hostname,
702                    port = port,
703                    refresh_interval_secs = refresh_interval.as_secs(),
704                    "Registered DNS service discovery"
705                );
706
707                Arc::new(DnsDiscovery::new(hostname, port, refresh_interval))
708            }
709            DiscoveryConfig::DnsSrv {
710                service,
711                refresh_interval,
712            } => {
713                info!(
714                    upstream_id = %upstream_id,
715                    service = %service,
716                    refresh_interval_secs = refresh_interval.as_secs(),
717                    "DNS SRV discovery not yet fully implemented, using DNS A record fallback"
718                );
719
720                // DNS SRV requires async DNS resolver - fall back to regular DNS for now
721                // Extract hostname from service name (e.g., "_http._tcp.example.com" -> "example.com")
722                let hostname = service
723                    .split('.')
724                    .skip_while(|s| s.starts_with('_'))
725                    .collect::<Vec<_>>()
726                    .join(".");
727                Arc::new(DnsDiscovery::new(hostname, 80, refresh_interval))
728            }
729            DiscoveryConfig::Consul {
730                address,
731                service,
732                datacenter,
733                only_passing,
734                refresh_interval,
735                tag,
736            } => {
737                info!(
738                    upstream_id = %upstream_id,
739                    address = %address,
740                    service = %service,
741                    datacenter = datacenter.as_deref().unwrap_or("default"),
742                    only_passing = only_passing,
743                    refresh_interval_secs = refresh_interval.as_secs(),
744                    "Registered Consul service discovery"
745                );
746
747                Arc::new(ConsulDiscovery::new(
748                    address,
749                    service,
750                    datacenter,
751                    only_passing,
752                    refresh_interval,
753                    tag,
754                ))
755            }
756            DiscoveryConfig::Kubernetes {
757                namespace,
758                service,
759                port_name,
760                refresh_interval,
761                kubeconfig,
762            } => {
763                info!(
764                    upstream_id = %upstream_id,
765                    namespace = %namespace,
766                    service = %service,
767                    port_name = port_name.as_deref().unwrap_or("default"),
768                    refresh_interval_secs = refresh_interval.as_secs(),
769                    "Registered Kubernetes endpoint discovery"
770                );
771
772                Arc::new(KubernetesDiscovery::new(
773                    namespace,
774                    service,
775                    port_name,
776                    refresh_interval,
777                    kubeconfig,
778                ))
779            }
780            DiscoveryConfig::File {
781                path,
782                watch_interval,
783            } => {
784                info!(
785                    upstream_id = %upstream_id,
786                    path = %path,
787                    watch_interval_secs = watch_interval.as_secs(),
788                    "File-based discovery not yet implemented, using empty static"
789                );
790
791                // TODO: Implement file-based discovery
792                Arc::new(StaticWrapper(StaticDiscovery::new(BTreeSet::new())))
793            }
794        };
795
796        self.discoveries
797            .write()
798            .insert(upstream_id.to_string(), discovery);
799        Ok(())
800    }
801
802    /// Get the discovery for an upstream
803    pub fn get(&self, upstream_id: &str) -> Option<Arc<dyn ServiceDiscovery + Send + Sync>> {
804        self.discoveries.read().get(upstream_id).cloned()
805    }
806
807    /// Discover backends for an upstream
808    pub async fn discover(
809        &self,
810        upstream_id: &str,
811    ) -> Option<Result<(BTreeSet<Backend>, HashMap<u64, bool>)>> {
812        let discovery = self.get(upstream_id)?;
813        Some(discovery.discover().await)
814    }
815
816    /// Remove discovery for an upstream
817    pub fn remove(&self, upstream_id: &str) {
818        self.discoveries.write().remove(upstream_id);
819    }
820
821    /// Number of registered discoveries
822    pub fn count(&self) -> usize {
823        self.discoveries.read().len()
824    }
825}
826
827impl Default for DiscoveryManager {
828    fn default() -> Self {
829        Self::new()
830    }
831}
832
833/// Wrapper for pingora's Static discovery to add Send + Sync
834struct StaticWrapper(Box<StaticDiscovery>);
835
836#[async_trait]
837impl ServiceDiscovery for StaticWrapper {
838    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
839        self.0.discover().await
840    }
841}
842
843// Make StaticWrapper Send + Sync safe since StaticDiscovery uses ArcSwap internally
844unsafe impl Send for StaticWrapper {}
845unsafe impl Sync for StaticWrapper {}
846
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is one static loopback backend.
    #[test]
    fn test_discovery_config_default() {
        match DiscoveryConfig::default() {
            DiscoveryConfig::Static { backends } => {
                assert_eq!(backends.len(), 1);
                assert_eq!(backends[0], "127.0.0.1:8080");
            }
            _ => panic!("Expected Static config"),
        }
    }

    /// Registering a static upstream makes its backends discoverable.
    #[tokio::test]
    async fn test_discovery_manager() {
        let manager = DiscoveryManager::new();

        // Register static discovery
        let config = DiscoveryConfig::Static {
            backends: vec![
                String::from("127.0.0.1:8080"),
                String::from("127.0.0.1:8081"),
            ],
        };
        manager.register("test-upstream", config).unwrap();

        assert_eq!(manager.count(), 1);

        // Discover backends
        let discovered = manager.discover("test-upstream").await;
        assert!(discovered.is_some());
        let (backends, _) = discovered.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }

    /// A freshly created DNS discovery always wants a refresh.
    #[test]
    fn test_dns_discovery_needs_refresh() {
        // Zero interval: immediate refresh
        let immediate = Duration::from_secs(0);
        let discovery = DnsDiscovery::new(String::from("localhost"), 8080, immediate);

        // Should need refresh immediately after creation
        assert!(discovery.needs_refresh());
    }

    /// All optional Consul settings show up as query parameters.
    #[test]
    fn test_consul_discovery_url_building() {
        let discovery = ConsulDiscovery::new(
            String::from("http://localhost:8500"),
            String::from("my-service"),
            Some(String::from("dc1")),
            true,
            Duration::from_secs(10),
            Some(String::from("production")),
        );

        let url = discovery.build_url();
        assert!(url.starts_with("http://localhost:8500/v1/health/service/my-service"));
        for expected in ["passing=true", "dc=dc1", "tag=production"] {
            assert!(url.contains(expected));
        }
    }

    /// Without optional settings the URL carries no query string.
    #[test]
    fn test_consul_discovery_url_minimal() {
        let discovery = ConsulDiscovery::new(
            String::from("http://consul.local:8500"),
            String::from("backend"),
            None,
            false,
            Duration::from_secs(30),
            None,
        );

        assert_eq!(
            discovery.build_url(),
            "http://consul.local:8500/v1/health/service/backend"
        );
    }

    /// A freshly created Kubernetes discovery always wants a refresh.
    #[test]
    fn test_kubernetes_discovery_config() {
        let discovery = KubernetesDiscovery::new(
            String::from("default"),
            String::from("my-service"),
            Some(String::from("http")),
            Duration::from_secs(10),
            None,
        );

        // Should need refresh immediately after creation
        assert!(discovery.needs_refresh());
    }

    /// An empty JSON array yields no backends.
    #[test]
    fn test_parse_consul_response_empty() {
        let backends = parse_consul_response("[]", "test").unwrap();
        assert!(backends.is_empty());
    }

    /// Consul configurations can be registered and looked up.
    #[tokio::test]
    async fn test_discovery_manager_consul() {
        let manager = DiscoveryManager::new();

        // Register Consul discovery
        let config = DiscoveryConfig::Consul {
            address: String::from("http://localhost:8500"),
            service: String::from("my-service"),
            datacenter: Some(String::from("dc1")),
            only_passing: true,
            refresh_interval: Duration::from_secs(10),
            tag: None,
        };
        manager.register("consul-upstream", config).unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("consul-upstream").is_some());
    }

    /// Kubernetes configurations can be registered and looked up.
    #[tokio::test]
    async fn test_discovery_manager_kubernetes() {
        let manager = DiscoveryManager::new();

        // Register Kubernetes discovery
        let config = DiscoveryConfig::Kubernetes {
            namespace: String::from("production"),
            service: String::from("api-server"),
            port_name: Some(String::from("http")),
            refresh_interval: Duration::from_secs(15),
            kubeconfig: None,
        };
        manager.register("k8s-upstream", config).unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("k8s-upstream").is_some());
    }
}