1use pingora_load_balancing::{
8 discovery::Static,
9 health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
10 Backend, Backends,
11};
12use std::collections::BTreeSet;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::RwLock;
16use tracing::{debug, info, trace, warn};
17
18use crate::grpc_health::GrpcHealthCheck;
19use crate::upstream::inference_health::InferenceHealthCheck;
20
21use sentinel_common::types::HealthCheckType;
22use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};
23
/// Performs periodic active health checks against the backends of a single
/// upstream, delegating the actual probing to Pingora's health-check machinery.
pub struct ActiveHealthChecker {
    /// Identifier of the upstream being monitored (from `UpstreamConfig::id`).
    upstream_id: String,
    /// Pingora backend set (static discovery) with the health check attached.
    backends: Arc<Backends>,
    /// Configured delay between health-check cycles for this upstream.
    interval: Duration,
    /// Whether backends are probed concurrently within one cycle
    /// (always set to `true` in `new`).
    parallel: bool,
    /// Optional callback for health-state transitions.
    /// NOTE(review): stored via `set_health_callback` but never invoked anywhere
    /// in this file — confirm it is wired up elsewhere, or remove it.
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}
40
/// Callback fired on a backend health-state change: `(backend_address, is_healthy)`.
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;
43
44impl ActiveHealthChecker {
45 pub fn new(config: &UpstreamConfig) -> Option<Self> {
47 let health_config = config.health_check.as_ref()?;
48
49 info!(
50 upstream_id = %config.id,
51 check_type = ?health_config.check_type,
52 interval_secs = health_config.interval_secs,
53 "Creating active health checker"
54 );
55
56 let mut backend_set = BTreeSet::new();
58 for target in &config.targets {
59 match Backend::new_with_weight(&target.address, target.weight as usize) {
60 Ok(backend) => {
61 debug!(
62 upstream_id = %config.id,
63 target = %target.address,
64 weight = target.weight,
65 "Added backend for health checking"
66 );
67 backend_set.insert(backend);
68 }
69 Err(e) => {
70 warn!(
71 upstream_id = %config.id,
72 target = %target.address,
73 error = %e,
74 "Failed to create backend for health checking"
75 );
76 }
77 }
78 }
79
80 if backend_set.is_empty() {
81 warn!(
82 upstream_id = %config.id,
83 "No backends created for health checking"
84 );
85 return None;
86 }
87
88 let discovery = Static::new(backend_set);
90 let mut backends = Backends::new(discovery);
91
92 let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
94 Self::create_health_check(health_config, &config.id);
95
96 backends.set_health_check(health_check);
97
98 Some(Self {
99 upstream_id: config.id.clone(),
100 backends: Arc::new(backends),
101 interval: Duration::from_secs(health_config.interval_secs),
102 parallel: true,
103 health_callback: Arc::new(RwLock::new(None)),
104 })
105 }
106
107 fn create_health_check(
109 config: &HealthCheckConfig,
110 upstream_id: &str,
111 ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
112 match &config.check_type {
113 HealthCheckType::Http {
114 path,
115 expected_status,
116 host,
117 } => {
118 let hostname = host.as_deref().unwrap_or("localhost");
119 let mut hc = HttpHealthCheck::new(hostname, false);
120
121 hc.consecutive_success = config.healthy_threshold as usize;
123 hc.consecutive_failure = config.unhealthy_threshold as usize;
124
125 if path != "/" {
129 if let Ok(req) =
131 pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
132 {
133 hc.req = req;
134 }
135 }
136
137 debug!(
141 upstream_id = %upstream_id,
142 path = %path,
143 expected_status = expected_status,
144 host = hostname,
145 consecutive_success = hc.consecutive_success,
146 consecutive_failure = hc.consecutive_failure,
147 "Created HTTP health check"
148 );
149
150 Box::new(hc)
151 }
152 HealthCheckType::Tcp => {
153 let mut hc = TcpHealthCheck::new();
155 hc.consecutive_success = config.healthy_threshold as usize;
156 hc.consecutive_failure = config.unhealthy_threshold as usize;
157
158 debug!(
159 upstream_id = %upstream_id,
160 consecutive_success = hc.consecutive_success,
161 consecutive_failure = hc.consecutive_failure,
162 "Created TCP health check"
163 );
164
165 hc
166 }
167 HealthCheckType::Grpc { service } => {
168 let timeout = Duration::from_secs(config.timeout_secs);
169 let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
170 hc.consecutive_success = config.healthy_threshold as usize;
171 hc.consecutive_failure = config.unhealthy_threshold as usize;
172
173 info!(
174 upstream_id = %upstream_id,
175 service = %service,
176 timeout_secs = config.timeout_secs,
177 consecutive_success = hc.consecutive_success,
178 consecutive_failure = hc.consecutive_failure,
179 "Created gRPC health check"
180 );
181
182 Box::new(hc)
183 }
184 HealthCheckType::Inference {
185 endpoint,
186 expected_models,
187 readiness: _,
188 } => {
189 let timeout = Duration::from_secs(config.timeout_secs);
191 let mut hc =
192 InferenceHealthCheck::new(endpoint.clone(), expected_models.clone(), timeout);
193 hc.consecutive_success = config.healthy_threshold as usize;
194 hc.consecutive_failure = config.unhealthy_threshold as usize;
195
196 info!(
197 upstream_id = %upstream_id,
198 endpoint = %endpoint,
199 expected_models = ?expected_models,
200 timeout_secs = config.timeout_secs,
201 consecutive_success = hc.consecutive_success,
202 consecutive_failure = hc.consecutive_failure,
203 "Created inference health check with model verification"
204 );
205
206 Box::new(hc)
207 }
208 }
209 }
210
211 pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
213 *self.health_callback.write().await = Some(callback);
214 }
215
216 pub async fn run_health_check(&self) {
218 trace!(
219 upstream_id = %self.upstream_id,
220 parallel = self.parallel,
221 "Running health check cycle"
222 );
223
224 self.backends.run_health_check(self.parallel).await;
225 }
226
227 pub fn is_backend_healthy(&self, address: &str) -> bool {
229 let backends = self.backends.get_backend();
230 for backend in backends.iter() {
231 if backend.addr.to_string() == address {
232 return self.backends.ready(backend);
233 }
234 }
235 true
237 }
238
239 pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
241 let backends = self.backends.get_backend();
242 backends
243 .iter()
244 .map(|b| {
245 let addr = b.addr.to_string();
246 let healthy = self.backends.ready(b);
247 (addr, healthy)
248 })
249 .collect()
250 }
251
252 pub fn interval(&self) -> Duration {
254 self.interval
255 }
256
257 pub fn upstream_id(&self) -> &str {
259 &self.upstream_id
260 }
261}
262
/// Drives all registered [`ActiveHealthChecker`]s from a single timer loop.
pub struct HealthCheckRunner {
    /// One checker per upstream that has active health checking configured.
    checkers: Vec<ActiveHealthChecker>,
    /// Cooperative shutdown flag: `run` exits when this becomes false.
    running: Arc<RwLock<bool>>,
}
270
271impl HealthCheckRunner {
272 pub fn new() -> Self {
274 Self {
275 checkers: Vec::new(),
276 running: Arc::new(RwLock::new(false)),
277 }
278 }
279
280 pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
282 info!(
283 upstream_id = %checker.upstream_id,
284 interval_secs = checker.interval.as_secs(),
285 "Added health checker to runner"
286 );
287 self.checkers.push(checker);
288 }
289
290 pub fn checker_count(&self) -> usize {
292 self.checkers.len()
293 }
294
295 pub async fn run(&self) {
297 if self.checkers.is_empty() {
298 info!("No health checkers configured, skipping health check loop");
299 return;
300 }
301
302 *self.running.write().await = true;
303
304 info!(
305 checker_count = self.checkers.len(),
306 "Starting health check runner"
307 );
308
309 let min_interval = self
311 .checkers
312 .iter()
313 .map(|c| c.interval)
314 .min()
315 .unwrap_or(Duration::from_secs(10));
316
317 let mut interval = tokio::time::interval(min_interval);
318 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
319
320 loop {
321 interval.tick().await;
322
323 if !*self.running.read().await {
324 info!("Health check runner stopped");
325 break;
326 }
327
328 for checker in &self.checkers {
330 checker.run_health_check().await;
331
332 let statuses = checker.get_health_statuses();
334 for (addr, healthy) in &statuses {
335 trace!(
336 upstream_id = %checker.upstream_id,
337 backend = %addr,
338 healthy = healthy,
339 "Backend health status"
340 );
341 }
342 }
343 }
344 }
345
346 pub async fn stop(&self) {
348 info!("Stopping health check runner");
349 *self.running.write().await = false;
350 }
351
352 pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
354 self.checkers
355 .iter()
356 .find(|c| c.upstream_id == upstream_id)
357 .map(|c| c.is_backend_healthy(address))
358 }
359
360 pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
362 self.checkers
363 .iter()
364 .find(|c| c.upstream_id == upstream_id)
365 .map(|c| c.get_health_statuses())
366 }
367}
368
369impl Default for HealthCheckRunner {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;
    use std::sync::Once;

    // Ensures the rustls crypto provider is installed exactly once across the
    // whole test binary; a second install attempt returns Err, which is ignored.
    static INIT: Once = Once::new();

    fn init_crypto_provider() {
        INIT.call_once(|| {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    // Minimal single-target upstream with an HTTP health check (5s interval).
    fn create_test_config() -> UpstreamConfig {
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![UpstreamTarget {
                address: "127.0.0.1:8081".to_string(),
                weight: 1,
                max_requests: None,
                metadata: HashMap::new(),
            }],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            sticky_session: None,
            health_check: Some(HealthCheckConfig {
                check_type: HealthCheckType::Http {
                    path: "/health".to_string(),
                    expected_status: 200,
                    host: None,
                },
                interval_secs: 5,
                timeout_secs: 2,
                healthy_threshold: 2,
                unhealthy_threshold: 3,
            }),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    // A config with a health_check section yields a checker carrying the
    // upstream id and configured interval.
    #[test]
    fn test_create_health_checker() {
        init_crypto_provider();
        let config = create_test_config();
        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_some());

        let checker = checker.unwrap();
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    // Without a health_check section, no checker is created.
    #[test]
    fn test_no_health_check_config() {
        let mut config = create_test_config();
        config.health_check = None;

        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_none());
    }

    // Runner starts empty and counts each added checker.
    #[test]
    fn test_health_check_runner() {
        init_crypto_provider();
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        let config = create_test_config();
        if let Some(checker) = ActiveHealthChecker::new(&config) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}