1use pingora_load_balancing::{
8 discovery::Static,
9 health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
10 Backend, Backends,
11};
12use std::collections::BTreeSet;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::RwLock;
16use tracing::{debug, info, trace, warn};
17
18use sentinel_common::types::HealthCheckType;
19use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};
20
/// Actively probes the backends of a single upstream using pingora's
/// health-check machinery. Constructed per upstream via [`ActiveHealthChecker::new`].
pub struct ActiveHealthChecker {
    // Id of the upstream this checker belongs to (`UpstreamConfig::id`).
    upstream_id: String,
    // pingora backend set (static discovery) with the health check attached.
    backends: Arc<Backends>,
    // Time between check cycles, from `health_check.interval_secs`.
    interval: Duration,
    // Whether a cycle probes all backends concurrently (always `true` here).
    parallel: bool,
    // Observer for health transitions, set via `set_health_callback`.
    // NOTE(review): stored but never invoked anywhere in this file — confirm
    // it is wired up elsewhere, otherwise callbacks are silently dropped.
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}

// Callback signature: presumably (backend address, is_healthy) — TODO confirm
// against the code that registers callbacks.
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;
40
41impl ActiveHealthChecker {
42 pub fn new(config: &UpstreamConfig) -> Option<Self> {
44 let health_config = config.health_check.as_ref()?;
45
46 info!(
47 upstream_id = %config.id,
48 check_type = ?health_config.check_type,
49 interval_secs = health_config.interval_secs,
50 "Creating active health checker"
51 );
52
53 let mut backend_set = BTreeSet::new();
55 for target in &config.targets {
56 match Backend::new_with_weight(&target.address, target.weight as usize) {
57 Ok(backend) => {
58 debug!(
59 upstream_id = %config.id,
60 target = %target.address,
61 weight = target.weight,
62 "Added backend for health checking"
63 );
64 backend_set.insert(backend);
65 }
66 Err(e) => {
67 warn!(
68 upstream_id = %config.id,
69 target = %target.address,
70 error = %e,
71 "Failed to create backend for health checking"
72 );
73 }
74 }
75 }
76
77 if backend_set.is_empty() {
78 warn!(
79 upstream_id = %config.id,
80 "No backends created for health checking"
81 );
82 return None;
83 }
84
85 let discovery = Static::new(backend_set);
87 let mut backends = Backends::new(discovery);
88
89 let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
91 Self::create_health_check(health_config, &config.id);
92
93 backends.set_health_check(health_check);
94
95 Some(Self {
96 upstream_id: config.id.clone(),
97 backends: Arc::new(backends),
98 interval: Duration::from_secs(health_config.interval_secs),
99 parallel: true,
100 health_callback: Arc::new(RwLock::new(None)),
101 })
102 }
103
104 fn create_health_check(
106 config: &HealthCheckConfig,
107 upstream_id: &str,
108 ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
109 match &config.check_type {
110 HealthCheckType::Http {
111 path,
112 expected_status,
113 host,
114 } => {
115 let hostname = host.as_deref().unwrap_or("localhost");
116 let mut hc = HttpHealthCheck::new(hostname, false);
117
118 hc.consecutive_success = config.healthy_threshold as usize;
120 hc.consecutive_failure = config.unhealthy_threshold as usize;
121
122 if path != "/" {
126 if let Ok(req) =
128 pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
129 {
130 hc.req = req;
131 }
132 }
133
134 debug!(
138 upstream_id = %upstream_id,
139 path = %path,
140 expected_status = expected_status,
141 host = hostname,
142 consecutive_success = hc.consecutive_success,
143 consecutive_failure = hc.consecutive_failure,
144 "Created HTTP health check"
145 );
146
147 Box::new(hc)
148 }
149 HealthCheckType::Tcp => {
150 let mut hc = TcpHealthCheck::new();
152 hc.consecutive_success = config.healthy_threshold as usize;
153 hc.consecutive_failure = config.unhealthy_threshold as usize;
154
155 debug!(
156 upstream_id = %upstream_id,
157 consecutive_success = hc.consecutive_success,
158 consecutive_failure = hc.consecutive_failure,
159 "Created TCP health check"
160 );
161
162 hc
163 }
164 HealthCheckType::Grpc { service } => {
165 warn!(
167 upstream_id = %upstream_id,
168 service = %service,
169 "gRPC health check not supported, falling back to TCP"
170 );
171 let mut hc = TcpHealthCheck::new();
173 hc.consecutive_success = config.healthy_threshold as usize;
174 hc.consecutive_failure = config.unhealthy_threshold as usize;
175 hc
176 }
177 }
178 }
179
180 pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
182 *self.health_callback.write().await = Some(callback);
183 }
184
185 pub async fn run_health_check(&self) {
187 trace!(
188 upstream_id = %self.upstream_id,
189 parallel = self.parallel,
190 "Running health check cycle"
191 );
192
193 self.backends.run_health_check(self.parallel).await;
194 }
195
196 pub fn is_backend_healthy(&self, address: &str) -> bool {
198 let backends = self.backends.get_backend();
199 for backend in backends.iter() {
200 if backend.addr.to_string() == address {
201 return self.backends.ready(backend);
202 }
203 }
204 true
206 }
207
208 pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
210 let backends = self.backends.get_backend();
211 backends
212 .iter()
213 .map(|b| {
214 let addr = b.addr.to_string();
215 let healthy = self.backends.ready(b);
216 (addr, healthy)
217 })
218 .collect()
219 }
220
221 pub fn interval(&self) -> Duration {
223 self.interval
224 }
225
226 pub fn upstream_id(&self) -> &str {
228 &self.upstream_id
229 }
230}
231
/// Drives periodic health-check cycles for a set of [`ActiveHealthChecker`]s
/// on a shared ticker.
pub struct HealthCheckRunner {
    // Checkers registered via `add_checker`, typically one per upstream.
    checkers: Vec<ActiveHealthChecker>,
    // Cooperative stop flag: set true by `run`, false by `stop`, re-read
    // by the loop on every tick.
    running: Arc<RwLock<bool>>,
}
239
240impl HealthCheckRunner {
241 pub fn new() -> Self {
243 Self {
244 checkers: Vec::new(),
245 running: Arc::new(RwLock::new(false)),
246 }
247 }
248
249 pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
251 info!(
252 upstream_id = %checker.upstream_id,
253 interval_secs = checker.interval.as_secs(),
254 "Added health checker to runner"
255 );
256 self.checkers.push(checker);
257 }
258
259 pub fn checker_count(&self) -> usize {
261 self.checkers.len()
262 }
263
264 pub async fn run(&self) {
266 if self.checkers.is_empty() {
267 info!("No health checkers configured, skipping health check loop");
268 return;
269 }
270
271 *self.running.write().await = true;
272
273 info!(
274 checker_count = self.checkers.len(),
275 "Starting health check runner"
276 );
277
278 let min_interval = self
280 .checkers
281 .iter()
282 .map(|c| c.interval)
283 .min()
284 .unwrap_or(Duration::from_secs(10));
285
286 let mut interval = tokio::time::interval(min_interval);
287 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
288
289 loop {
290 interval.tick().await;
291
292 if !*self.running.read().await {
293 info!("Health check runner stopped");
294 break;
295 }
296
297 for checker in &self.checkers {
299 checker.run_health_check().await;
300
301 let statuses = checker.get_health_statuses();
303 for (addr, healthy) in &statuses {
304 trace!(
305 upstream_id = %checker.upstream_id,
306 backend = %addr,
307 healthy = healthy,
308 "Backend health status"
309 );
310 }
311 }
312 }
313 }
314
315 pub async fn stop(&self) {
317 info!("Stopping health check runner");
318 *self.running.write().await = false;
319 }
320
321 pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
323 self.checkers
324 .iter()
325 .find(|c| c.upstream_id == upstream_id)
326 .map(|c| c.is_backend_healthy(address))
327 }
328
329 pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
331 self.checkers
332 .iter()
333 .find(|c| c.upstream_id == upstream_id)
334 .map(|c| c.get_health_statuses())
335 }
336}
337
impl Default for HealthCheckRunner {
    /// Equivalent to [`HealthCheckRunner::new`]: an empty, stopped runner.
    fn default() -> Self {
        Self::new()
    }
}
343
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;

    /// Minimal upstream fixture: one target plus an HTTP health check on
    /// `/health` every 5 seconds.
    fn create_test_config() -> UpstreamConfig {
        let target = UpstreamTarget {
            address: "127.0.0.1:8081".to_string(),
            weight: 1,
            max_requests: None,
            metadata: HashMap::new(),
        };
        let health = HealthCheckConfig {
            check_type: HealthCheckType::Http {
                path: "/health".to_string(),
                expected_status: 200,
                host: None,
            },
            interval_secs: 5,
            timeout_secs: 2,
            healthy_threshold: 2,
            unhealthy_threshold: 3,
        };
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![target],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            health_check: Some(health),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        let maybe_checker = ActiveHealthChecker::new(&create_test_config());
        assert!(maybe_checker.is_some());

        // Fields should mirror the fixture's id and interval.
        let checker = maybe_checker.unwrap();
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        // Without a health_check section no checker should be built.
        let config = UpstreamConfig {
            health_check: None,
            ..create_test_config()
        };
        assert!(ActiveHealthChecker::new(&config).is_none());
    }

    #[test]
    fn test_health_check_runner() {
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        if let Some(checker) = ActiveHealthChecker::new(&create_test_config()) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}