Skip to main content

sentinel_proxy/upstream/
health.rs

1//! Active health checking using Pingora's HttpHealthCheck
2//!
3//! This module provides active health probing for upstream backends using
4//! Pingora's built-in health check infrastructure. It complements the passive
5//! health tracking in load balancers by periodically probing backends.
6
7use pingora_load_balancing::{
8    discovery::Static,
9    health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
10    Backend, Backends,
11};
12use std::collections::BTreeSet;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::RwLock;
16use tracing::{debug, info, trace, warn};
17
18use crate::grpc_health::GrpcHealthCheck;
19use crate::upstream::inference_health::InferenceHealthCheck;
20
21use sentinel_common::types::HealthCheckType;
22use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};
23
/// Active health checker for an upstream pool.
///
/// Wraps Pingora's `Backends` struct with a health check installed via
/// `set_health_check`. Probes are driven externally by calling
/// [`run_health_check`](ActiveHealthChecker::run_health_check); results are
/// read back by polling [`get_health_statuses`](ActiveHealthChecker::get_health_statuses).
pub struct ActiveHealthChecker {
    /// ID of the upstream this checker probes (copied from `UpstreamConfig.id`).
    upstream_id: String,
    /// Pingora backends with a health check attached; shared so status can be
    /// queried while probe cycles run.
    backends: Arc<Backends>,
    /// Desired interval between probe cycles (from `health_check.interval_secs`).
    interval: Duration,
    /// Whether `Backends::run_health_check` probes backends in parallel
    /// (hard-coded to `true` in `new`).
    parallel: bool,
    /// Optional callback for health changes.
    /// NOTE(review): set via `set_health_callback` but never invoked anywhere
    /// in this module — confirm whether callers rely on it before removing.
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}
40
/// Callback type for health status changes.
///
/// Presumably called with `(backend_address, is_healthy)` — the signature
/// matches how addresses and health flags are used elsewhere in this module.
/// NOTE(review): nothing in this module invokes the stored callback; health is
/// currently observed by polling `get_health_statuses()` instead.
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;
43
44impl ActiveHealthChecker {
45    /// Create a new active health checker from upstream config
46    pub fn new(config: &UpstreamConfig) -> Option<Self> {
47        let health_config = config.health_check.as_ref()?;
48
49        info!(
50            upstream_id = %config.id,
51            check_type = ?health_config.check_type,
52            interval_secs = health_config.interval_secs,
53            "Creating active health checker"
54        );
55
56        // Create backends from targets
57        let mut backend_set = BTreeSet::new();
58        for target in &config.targets {
59            match Backend::new_with_weight(&target.address, target.weight as usize) {
60                Ok(backend) => {
61                    debug!(
62                        upstream_id = %config.id,
63                        target = %target.address,
64                        weight = target.weight,
65                        "Added backend for health checking"
66                    );
67                    backend_set.insert(backend);
68                }
69                Err(e) => {
70                    warn!(
71                        upstream_id = %config.id,
72                        target = %target.address,
73                        error = %e,
74                        "Failed to create backend for health checking"
75                    );
76                }
77            }
78        }
79
80        if backend_set.is_empty() {
81            warn!(
82                upstream_id = %config.id,
83                "No backends created for health checking"
84            );
85            return None;
86        }
87
88        // Create static discovery (Static::new returns Box<Self>)
89        let discovery = Static::new(backend_set);
90        let mut backends = Backends::new(discovery);
91
92        // Create and configure health check
93        let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
94            Self::create_health_check(health_config, &config.id);
95
96        backends.set_health_check(health_check);
97
98        Some(Self {
99            upstream_id: config.id.clone(),
100            backends: Arc::new(backends),
101            interval: Duration::from_secs(health_config.interval_secs),
102            parallel: true,
103            health_callback: Arc::new(RwLock::new(None)),
104        })
105    }
106
107    /// Create the appropriate health check based on config
108    fn create_health_check(
109        config: &HealthCheckConfig,
110        upstream_id: &str,
111    ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
112        match &config.check_type {
113            HealthCheckType::Http {
114                path,
115                expected_status,
116                host,
117            } => {
118                let hostname = host.as_deref().unwrap_or("localhost");
119                let mut hc = HttpHealthCheck::new(hostname, false);
120
121                // Configure thresholds
122                hc.consecutive_success = config.healthy_threshold as usize;
123                hc.consecutive_failure = config.unhealthy_threshold as usize;
124
125                // Configure request path
126                // Note: HttpHealthCheck sends GET to / by default
127                // We customize by modifying hc.req for non-root paths
128                if path != "/" {
129                    // Create custom request header for the health check path
130                    if let Ok(req) =
131                        pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
132                    {
133                        hc.req = req;
134                    }
135                }
136
137                // Note: health_changed_callback requires implementing HealthObserve trait
138                // We use polling via run_health_check() and get_health_statuses() instead
139
140                debug!(
141                    upstream_id = %upstream_id,
142                    path = %path,
143                    expected_status = expected_status,
144                    host = hostname,
145                    consecutive_success = hc.consecutive_success,
146                    consecutive_failure = hc.consecutive_failure,
147                    "Created HTTP health check"
148                );
149
150                Box::new(hc)
151            }
152            HealthCheckType::Tcp => {
153                // TcpHealthCheck::new() returns Box<Self>
154                let mut hc = TcpHealthCheck::new();
155                hc.consecutive_success = config.healthy_threshold as usize;
156                hc.consecutive_failure = config.unhealthy_threshold as usize;
157
158                debug!(
159                    upstream_id = %upstream_id,
160                    consecutive_success = hc.consecutive_success,
161                    consecutive_failure = hc.consecutive_failure,
162                    "Created TCP health check"
163                );
164
165                hc
166            }
167            HealthCheckType::Grpc { service } => {
168                let timeout = Duration::from_secs(config.timeout_secs);
169                let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
170                hc.consecutive_success = config.healthy_threshold as usize;
171                hc.consecutive_failure = config.unhealthy_threshold as usize;
172
173                info!(
174                    upstream_id = %upstream_id,
175                    service = %service,
176                    timeout_secs = config.timeout_secs,
177                    consecutive_success = hc.consecutive_success,
178                    consecutive_failure = hc.consecutive_failure,
179                    "Created gRPC health check"
180                );
181
182                Box::new(hc)
183            }
184            HealthCheckType::Inference {
185                endpoint,
186                expected_models,
187                readiness: _,
188            } => {
189                // Inference health check that verifies expected models are available
190                let timeout = Duration::from_secs(config.timeout_secs);
191                let mut hc =
192                    InferenceHealthCheck::new(endpoint.clone(), expected_models.clone(), timeout);
193                hc.consecutive_success = config.healthy_threshold as usize;
194                hc.consecutive_failure = config.unhealthy_threshold as usize;
195
196                info!(
197                    upstream_id = %upstream_id,
198                    endpoint = %endpoint,
199                    expected_models = ?expected_models,
200                    timeout_secs = config.timeout_secs,
201                    consecutive_success = hc.consecutive_success,
202                    consecutive_failure = hc.consecutive_failure,
203                    "Created inference health check with model verification"
204                );
205
206                Box::new(hc)
207            }
208        }
209    }
210
211    /// Set callback for health status changes
212    pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
213        *self.health_callback.write().await = Some(callback);
214    }
215
216    /// Run a single health check cycle
217    pub async fn run_health_check(&self) {
218        trace!(
219            upstream_id = %self.upstream_id,
220            parallel = self.parallel,
221            "Running health check cycle"
222        );
223
224        self.backends.run_health_check(self.parallel).await;
225    }
226
227    /// Check if a specific backend is healthy
228    pub fn is_backend_healthy(&self, address: &str) -> bool {
229        let backends = self.backends.get_backend();
230        for backend in backends.iter() {
231            if backend.addr.to_string() == address {
232                return self.backends.ready(backend);
233            }
234        }
235        // Unknown backend, assume healthy
236        true
237    }
238
239    /// Get all backend health statuses
240    pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
241        let backends = self.backends.get_backend();
242        backends
243            .iter()
244            .map(|b| {
245                let addr = b.addr.to_string();
246                let healthy = self.backends.ready(b);
247                (addr, healthy)
248            })
249            .collect()
250    }
251
252    /// Get the health check interval
253    pub fn interval(&self) -> Duration {
254        self.interval
255    }
256
257    /// Get the upstream ID
258    pub fn upstream_id(&self) -> &str {
259        &self.upstream_id
260    }
261}
262
/// Health check runner that manages multiple upstream health checkers.
///
/// Owns one [`ActiveHealthChecker`] per upstream and drives them from a single
/// async loop (see `run`); `stop` flips the shared flag to end the loop.
pub struct HealthCheckRunner {
    /// Health checkers, one per upstream (looked up by `upstream_id`).
    checkers: Vec<ActiveHealthChecker>,
    /// Shared run flag: set `true` by `run`, `false` by `stop`; checked each tick.
    running: Arc<RwLock<bool>>,
}
270
271impl HealthCheckRunner {
272    /// Create a new health check runner
273    pub fn new() -> Self {
274        Self {
275            checkers: Vec::new(),
276            running: Arc::new(RwLock::new(false)),
277        }
278    }
279
280    /// Add a health checker for an upstream
281    pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
282        info!(
283            upstream_id = %checker.upstream_id,
284            interval_secs = checker.interval.as_secs(),
285            "Added health checker to runner"
286        );
287        self.checkers.push(checker);
288    }
289
290    /// Get the number of health checkers
291    pub fn checker_count(&self) -> usize {
292        self.checkers.len()
293    }
294
295    /// Start the health check loop (runs until stopped)
296    pub async fn run(&self) {
297        if self.checkers.is_empty() {
298            info!("No health checkers configured, skipping health check loop");
299            return;
300        }
301
302        *self.running.write().await = true;
303
304        info!(
305            checker_count = self.checkers.len(),
306            "Starting health check runner"
307        );
308
309        // Find minimum interval
310        let min_interval = self
311            .checkers
312            .iter()
313            .map(|c| c.interval)
314            .min()
315            .unwrap_or(Duration::from_secs(10));
316
317        let mut interval = tokio::time::interval(min_interval);
318        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
319
320        loop {
321            interval.tick().await;
322
323            if !*self.running.read().await {
324                info!("Health check runner stopped");
325                break;
326            }
327
328            // Run health checks for all upstreams
329            for checker in &self.checkers {
330                checker.run_health_check().await;
331
332                // Log current health statuses
333                let statuses = checker.get_health_statuses();
334                for (addr, healthy) in &statuses {
335                    trace!(
336                        upstream_id = %checker.upstream_id,
337                        backend = %addr,
338                        healthy = healthy,
339                        "Backend health status"
340                    );
341                }
342            }
343        }
344    }
345
346    /// Stop the health check loop
347    pub async fn stop(&self) {
348        info!("Stopping health check runner");
349        *self.running.write().await = false;
350    }
351
352    /// Get health status for a specific upstream and backend
353    pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
354        self.checkers
355            .iter()
356            .find(|c| c.upstream_id == upstream_id)
357            .map(|c| c.is_backend_healthy(address))
358    }
359
360    /// Get all health statuses for an upstream
361    pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
362        self.checkers
363            .iter()
364            .find(|c| c.upstream_id == upstream_id)
365            .map(|c| c.get_health_statuses())
366    }
367}
368
369impl Default for HealthCheckRunner {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;
    use std::sync::Once;

    static INIT: Once = Once::new();

    /// Install the process-wide rustls crypto provider exactly once.
    fn init_crypto_provider() {
        INIT.call_once(|| {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    /// Build an upstream config with one target and an HTTP health check.
    fn test_upstream_config() -> UpstreamConfig {
        let target = UpstreamTarget {
            address: "127.0.0.1:8081".to_string(),
            weight: 1,
            max_requests: None,
            metadata: HashMap::new(),
        };
        let http_check = HealthCheckConfig {
            check_type: HealthCheckType::Http {
                path: "/health".to_string(),
                expected_status: 200,
                host: None,
            },
            interval_secs: 5,
            timeout_secs: 2,
            healthy_threshold: 2,
            unhealthy_threshold: 3,
        };
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![target],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            sticky_session: None,
            health_check: Some(http_check),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        init_crypto_provider();
        let checker = ActiveHealthChecker::new(&test_upstream_config())
            .expect("checker should be created from a config with health_check");
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        let config = UpstreamConfig {
            health_check: None,
            ..test_upstream_config()
        };
        assert!(ActiveHealthChecker::new(&config).is_none());
    }

    #[test]
    fn test_health_check_runner() {
        init_crypto_provider();
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        if let Some(checker) = ActiveHealthChecker::new(&test_upstream_config()) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}
455}