use pingora_load_balancing::{
    discovery::Static,
    health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
    Backend, Backends,
};
use std::collections::BTreeSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, trace, warn};

use crate::grpc_health::GrpcHealthCheck;

use sentinel_common::types::HealthCheckType;
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};

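/// Active health checker for a single upstream, built on Pingora's
/// load-balancing `Backends` and its pluggable `HealthCheck` implementations.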
pub struct ActiveHealthChecker {
    /// ID of the upstream this checker monitors.
    upstream_id: String,
    /// Pingora backend collection that tracks per-backend health state.
    backends: Arc<Backends>,
    /// Interval between health check cycles.
    interval: Duration,
    /// Whether backends are checked in parallel.
    parallel: bool,
    /// Callback registered via `set_health_callback`.
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}

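/// Callback signature for health change notifications: `(backend_address, is_healthy)`.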
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;

impl ActiveHealthChecker {
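    /// Creates a health checker from an upstream's configuration.
    ///
    /// Returns `None` when the upstream has no health check configured or when
    /// no backend could be constructed from its targets.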
    pub fn new(config: &UpstreamConfig) -> Option<Self> {
        // Only build a checker when health checking is configured for this upstream.
        let health_config = config.health_check.as_ref()?;

        info!(
            upstream_id = %config.id,
            check_type = ?health_config.check_type,
            interval_secs = health_config.interval_secs,
            "Creating active health checker"
        );

        // Build a Pingora backend entry for every configured target.
        let mut backend_set = BTreeSet::new();
        for target in &config.targets {
            match Backend::new_with_weight(&target.address, target.weight as usize) {
                Ok(backend) => {
                    debug!(
                        upstream_id = %config.id,
                        target = %target.address,
                        weight = target.weight,
                        "Added backend for health checking"
                    );
                    backend_set.insert(backend);
                }
                Err(e) => {
                    warn!(
                        upstream_id = %config.id,
                        target = %target.address,
                        error = %e,
                        "Failed to create backend for health checking"
                    );
                }
            }
        }

        if backend_set.is_empty() {
            warn!(
                upstream_id = %config.id,
                "No backends created for health checking"
            );
            return None;
        }

        // Static discovery: the backend set does not change after construction.
        let discovery = Static::new(backend_set);
        let mut backends = Backends::new(discovery);

        let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
            Self::create_health_check(health_config, &config.id);

        backends.set_health_check(health_check);

        Some(Self {
            upstream_id: config.id.clone(),
            backends: Arc::new(backends),
            interval: Duration::from_secs(health_config.interval_secs),
            parallel: true,
            health_callback: Arc::new(RwLock::new(None)),
        })
    }

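    /// Builds the Pingora `HealthCheck` implementation matching the configured
    /// check type (HTTP, TCP, or gRPC) and applies the healthy/unhealthy thresholds.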
    fn create_health_check(
        config: &HealthCheckConfig,
        upstream_id: &str,
    ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
        match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                let hostname = host.as_deref().unwrap_or("localhost");
                let mut hc = HttpHealthCheck::new(hostname, false);

                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                // Use a custom GET request when the health path is not the default "/".
                if path != "/" {
                    if let Ok(req) =
                        pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
                    {
                        hc.req = req;
                    }
                }

                debug!(
                    upstream_id = %upstream_id,
                    path = %path,
                    expected_status = expected_status,
                    host = hostname,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created HTTP health check"
                );

                Box::new(hc)
            }
            HealthCheckType::Tcp => {
                let mut hc = TcpHealthCheck::new();
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                debug!(
                    upstream_id = %upstream_id,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created TCP health check"
                );

                hc
            }
            HealthCheckType::Grpc { service } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                info!(
                    upstream_id = %upstream_id,
                    service = %service,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created gRPC health check"
                );

                Box::new(hc)
            }
        }
    }

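    /// Registers a callback to be notified of backend health changes.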
    pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
        *self.health_callback.write().await = Some(callback);
    }

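    /// Runs one health check cycle across all backends of this upstream.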
    pub async fn run_health_check(&self) {
        trace!(
            upstream_id = %self.upstream_id,
            parallel = self.parallel,
            "Running health check cycle"
        );

        self.backends.run_health_check(self.parallel).await;
    }

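    /// Returns whether the backend at `address` is currently considered healthy.
    /// Unknown addresses are treated as healthy.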
    pub fn is_backend_healthy(&self, address: &str) -> bool {
        let backends = self.backends.get_backend();
        for backend in backends.iter() {
            if backend.addr.to_string() == address {
                return self.backends.ready(backend);
            }
        }
        // Default to healthy for backends not tracked by this checker.
        true
    }

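    /// Returns the `(address, healthy)` status of every backend in this upstream.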
    pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
        let backends = self.backends.get_backend();
        backends
            .iter()
            .map(|b| {
                let addr = b.addr.to_string();
                let healthy = self.backends.ready(b);
                (addr, healthy)
            })
            .collect()
    }

    pub fn interval(&self) -> Duration {
        self.interval
    }

    pub fn upstream_id(&self) -> &str {
        &self.upstream_id
    }
}

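/// Drives periodic health checks for all registered upstream checkers.
///
/// A minimal wiring sketch (the surrounding variables are illustrative, not part
/// of this module):
///
/// ```ignore
/// let mut runner = HealthCheckRunner::new();
/// if let Some(checker) = ActiveHealthChecker::new(&upstream_config) {
///     runner.add_checker(checker);
/// }
/// let runner = Arc::new(runner);
/// let handle = {
///     let runner = runner.clone();
///     tokio::spawn(async move { runner.run().await })
/// };
/// // ... later, on shutdown:
/// // runner.stop().await;
/// ```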
pub struct HealthCheckRunner {
    /// One checker per upstream that has active health checking configured.
    checkers: Vec<ActiveHealthChecker>,
    /// Flag used to stop the run loop.
    running: Arc<RwLock<bool>>,
}

impl HealthCheckRunner {
    pub fn new() -> Self {
        Self {
            checkers: Vec::new(),
            running: Arc::new(RwLock::new(false)),
        }
    }

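    /// Registers a checker with the runner.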
    pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
        info!(
            upstream_id = %checker.upstream_id,
            interval_secs = checker.interval.as_secs(),
            "Added health checker to runner"
        );
        self.checkers.push(checker);
    }

    pub fn checker_count(&self) -> usize {
        self.checkers.len()
    }

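    /// Runs the health check loop until `stop` is called.
    ///
    /// All checkers are driven by a single ticker set to the smallest configured
    /// interval; each tick runs every checker's health check cycle.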
    pub async fn run(&self) {
        if self.checkers.is_empty() {
            info!("No health checkers configured, skipping health check loop");
            return;
        }

        *self.running.write().await = true;

        info!(
            checker_count = self.checkers.len(),
            "Starting health check runner"
        );

        // Tick at the smallest interval across all checkers (default 10s).
        let min_interval = self
            .checkers
            .iter()
            .map(|c| c.interval)
            .min()
            .unwrap_or(Duration::from_secs(10));

        let mut interval = tokio::time::interval(min_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            if !*self.running.read().await {
                info!("Health check runner stopped");
                break;
            }

            for checker in &self.checkers {
                checker.run_health_check().await;

                let statuses = checker.get_health_statuses();
                for (addr, healthy) in &statuses {
                    trace!(
                        upstream_id = %checker.upstream_id,
                        backend = %addr,
                        healthy = healthy,
                        "Backend health status"
                    );
                }
            }
        }
    }

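    /// Signals the run loop to exit on its next tick.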
    pub async fn stop(&self) {
        info!("Stopping health check runner");
        *self.running.write().await = false;
    }

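    /// Returns the health of a specific backend, or `None` if the upstream has no checker.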
    pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.is_backend_healthy(address))
    }

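    /// Returns the per-backend health statuses for an upstream, or `None` if it has no checker.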
    pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.get_health_statuses())
    }
}

impl Default for HealthCheckRunner {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;

    fn create_test_config() -> UpstreamConfig {
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![UpstreamTarget {
                address: "127.0.0.1:8081".to_string(),
                weight: 1,
                max_requests: None,
                metadata: HashMap::new(),
            }],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            health_check: Some(HealthCheckConfig {
                check_type: HealthCheckType::Http {
                    path: "/health".to_string(),
                    expected_status: 200,
                    host: None,
                },
                interval_secs: 5,
                timeout_secs: 2,
                healthy_threshold: 2,
                unhealthy_threshold: 3,
            }),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        let config = create_test_config();
        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_some());

        let checker = checker.unwrap();
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        let mut config = create_test_config();
        config.health_check = None;

        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_none());
    }

    #[test]
    fn test_health_check_runner() {
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        let config = create_test_config();
        if let Some(checker) = ActiveHealthChecker::new(&config) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}