sentinel_proxy/upstream/health.rs

//! Active health checking built on Pingora's health check infrastructure
//!
//! This module provides active health probing for upstream backends using
//! Pingora's built-in HTTP and TCP health checks plus this crate's gRPC check.
//! It complements the passive health tracking in the load balancers by
//! periodically probing backends.
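//!
//! A rough usage sketch (assuming an `UpstreamConfig` value named `config` has
//! been built elsewhere, e.g. from the gateway configuration):
//!
//! ```ignore
//! if let Some(checker) = ActiveHealthChecker::new(&config) {
//!     // Probe every backend once, then inspect the results.
//!     checker.run_health_check().await;
//!     for (addr, healthy) in checker.get_health_statuses() {
//!         tracing::debug!(backend = %addr, healthy, "probe result");
//!     }
//! }
//! ```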

use pingora_load_balancing::{
    discovery::Static,
    health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
    Backend, Backends,
};
use std::collections::BTreeSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, trace, warn};

use crate::grpc_health::GrpcHealthCheck;

use sentinel_common::types::HealthCheckType;
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};

/// Active health checker for an upstream pool
///
/// This wraps Pingora's `Backends` struct with health checking enabled.
/// It runs periodic health probes and reports status back to the load balancer.
pub struct ActiveHealthChecker {
    /// Upstream ID
    upstream_id: String,
    /// Pingora backends with health checking
    backends: Arc<Backends>,
    /// Health check interval
    interval: Duration,
    /// Whether to run checks in parallel
    parallel: bool,
    /// Callback to notify load balancer of health changes
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}

/// Callback type for health status changes
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;

impl ActiveHealthChecker {
    /// Create a new active health checker from upstream config
    pub fn new(config: &UpstreamConfig) -> Option<Self> {
        let health_config = config.health_check.as_ref()?;

        info!(
            upstream_id = %config.id,
            check_type = ?health_config.check_type,
            interval_secs = health_config.interval_secs,
            "Creating active health checker"
        );

        // Create backends from targets
        let mut backend_set = BTreeSet::new();
        for target in &config.targets {
            match Backend::new_with_weight(&target.address, target.weight as usize) {
                Ok(backend) => {
                    debug!(
                        upstream_id = %config.id,
                        target = %target.address,
                        weight = target.weight,
                        "Added backend for health checking"
                    );
                    backend_set.insert(backend);
                }
                Err(e) => {
                    warn!(
                        upstream_id = %config.id,
                        target = %target.address,
                        error = %e,
                        "Failed to create backend for health checking"
                    );
                }
            }
        }

        if backend_set.is_empty() {
            warn!(
                upstream_id = %config.id,
                "No backends created for health checking"
            );
            return None;
        }

        // Create static discovery (Static::new returns Box<Self>)
        let discovery = Static::new(backend_set);
        let mut backends = Backends::new(discovery);

        // Create and configure health check
        let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
            Self::create_health_check(health_config, &config.id);

        backends.set_health_check(health_check);

        Some(Self {
            upstream_id: config.id.clone(),
            backends: Arc::new(backends),
            interval: Duration::from_secs(health_config.interval_secs),
            parallel: true,
            health_callback: Arc::new(RwLock::new(None)),
        })
    }

    /// Create the appropriate health check based on config
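    ///
    /// As a rough illustration (mirroring the test config at the bottom of this
    /// file), a config such as:
    ///
    /// ```ignore
    /// HealthCheckConfig {
    ///     check_type: HealthCheckType::Http {
    ///         path: "/health".to_string(),
    ///         expected_status: 200,
    ///         host: None,
    ///     },
    ///     interval_secs: 5,
    ///     timeout_secs: 2,
    ///     healthy_threshold: 2,
    ///     unhealthy_threshold: 3,
    /// }
    /// ```
    ///
    /// produces an `HttpHealthCheck` probing `GET /health`, with the two
    /// thresholds used as the consecutive success/failure counts.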
    fn create_health_check(
        config: &HealthCheckConfig,
        upstream_id: &str,
    ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
        match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                let hostname = host.as_deref().unwrap_or("localhost");
                let mut hc = HttpHealthCheck::new(hostname, false);

                // Configure thresholds
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                // Configure request path
                // Note: HttpHealthCheck sends GET to / by default
                // We customize by modifying hc.req for non-root paths
                if path != "/" {
                    // Create custom request header for the health check path
                    if let Ok(req) =
                        pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
                    {
                        hc.req = req;
                    }
                }

                // Note: health_changed_callback requires implementing the HealthObserve
                // trait; we poll via run_health_check() and get_health_statuses() instead.
                // Note: expected_status is currently only used for logging below; response
                // validation is left to HttpHealthCheck's default behavior.

                debug!(
                    upstream_id = %upstream_id,
                    path = %path,
                    expected_status = expected_status,
                    host = hostname,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created HTTP health check"
                );

                Box::new(hc)
            }
            HealthCheckType::Tcp => {
                // TcpHealthCheck::new() returns Box<Self>
                let mut hc = TcpHealthCheck::new();
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                debug!(
                    upstream_id = %upstream_id,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created TCP health check"
                );

                hc
            }
            HealthCheckType::Grpc { service } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                info!(
                    upstream_id = %upstream_id,
                    service = %service,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created gRPC health check"
                );

                Box::new(hc)
            }
        }
    }

    /// Set callback for health status changes
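    ///
    /// A minimal sketch of registering a callback (the `&str` argument is assumed
    /// here to be the backend address and the `bool` its new health state):
    ///
    /// ```ignore
    /// checker
    ///     .set_health_callback(Box::new(|addr, healthy| {
    ///         tracing::info!(backend = %addr, healthy, "backend health changed");
    ///     }))
    ///     .await;
    /// ```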
    pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
        *self.health_callback.write().await = Some(callback);
    }

    /// Run a single health check cycle
    pub async fn run_health_check(&self) {
        trace!(
            upstream_id = %self.upstream_id,
            parallel = self.parallel,
            "Running health check cycle"
        );

        self.backends.run_health_check(self.parallel).await;
    }

    /// Check if a specific backend is healthy
    pub fn is_backend_healthy(&self, address: &str) -> bool {
        let backends = self.backends.get_backend();
        for backend in backends.iter() {
            if backend.addr.to_string() == address {
                return self.backends.ready(backend);
            }
        }
        // Unknown backend, assume healthy
        true
    }

    /// Get all backend health statuses
    pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
        let backends = self.backends.get_backend();
        backends
            .iter()
            .map(|b| {
                let addr = b.addr.to_string();
                let healthy = self.backends.ready(b);
                (addr, healthy)
            })
            .collect()
    }

    /// Get the health check interval
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// Get the upstream ID
    pub fn upstream_id(&self) -> &str {
        &self.upstream_id
    }
}

/// Health check runner that manages multiple upstream health checkers
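///
/// A rough lifecycle sketch (assuming the runner is wrapped in an `Arc` so the
/// check loop can run as a spawned task while a handle is kept for shutdown):
///
/// ```ignore
/// let runner = Arc::new(runner);
/// let loop_handle = {
///     let runner = Arc::clone(&runner);
///     tokio::spawn(async move { runner.run().await })
/// };
/// // ... later, on shutdown:
/// runner.stop().await;
/// loop_handle.await.ok();
/// ```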
pub struct HealthCheckRunner {
    /// Health checkers per upstream
    checkers: Vec<ActiveHealthChecker>,
    /// Whether the runner is active
    running: Arc<RwLock<bool>>,
}

impl HealthCheckRunner {
    /// Create a new health check runner
    pub fn new() -> Self {
        Self {
            checkers: Vec::new(),
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Add a health checker for an upstream
    pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
        info!(
            upstream_id = %checker.upstream_id,
            interval_secs = checker.interval.as_secs(),
            "Added health checker to runner"
        );
        self.checkers.push(checker);
    }

    /// Get the number of health checkers
    pub fn checker_count(&self) -> usize {
        self.checkers.len()
    }

    /// Start the health check loop (runs until stopped)
    pub async fn run(&self) {
        if self.checkers.is_empty() {
            info!("No health checkers configured, skipping health check loop");
            return;
        }

        *self.running.write().await = true;

        info!(
            checker_count = self.checkers.len(),
            "Starting health check runner"
        );

        // Find minimum interval
        let min_interval = self
            .checkers
            .iter()
            .map(|c| c.interval)
            .min()
            .unwrap_or(Duration::from_secs(10));

        let mut interval = tokio::time::interval(min_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            if !*self.running.read().await {
                info!("Health check runner stopped");
                break;
            }

            // Run health checks for all upstreams
            for checker in &self.checkers {
                checker.run_health_check().await;

                // Log current health statuses
                let statuses = checker.get_health_statuses();
                for (addr, healthy) in &statuses {
                    trace!(
                        upstream_id = %checker.upstream_id,
                        backend = %addr,
                        healthy = healthy,
                        "Backend health status"
                    );
                }
            }
        }
    }

    /// Stop the health check loop
    pub async fn stop(&self) {
        info!("Stopping health check runner");
        *self.running.write().await = false;
    }

    /// Get health status for a specific upstream and backend
    pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.is_backend_healthy(address))
    }

    /// Get all health statuses for an upstream
    pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.get_health_statuses())
    }
}

impl Default for HealthCheckRunner {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;

    fn create_test_config() -> UpstreamConfig {
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![UpstreamTarget {
                address: "127.0.0.1:8081".to_string(),
                weight: 1,
                max_requests: None,
                metadata: HashMap::new(),
            }],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            health_check: Some(HealthCheckConfig {
                check_type: HealthCheckType::Http {
                    path: "/health".to_string(),
                    expected_status: 200,
                    host: None,
                },
                interval_secs: 5,
                timeout_secs: 2,
                healthy_threshold: 2,
                unhealthy_threshold: 3,
            }),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        let config = create_test_config();
        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_some());

        let checker = checker.unwrap();
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        let mut config = create_test_config();
        config.health_check = None;

        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_none());
    }

    #[test]
    fn test_health_check_runner() {
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        let config = create_test_config();
        if let Some(checker) = ActiveHealthChecker::new(&config) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}