1use async_trait::async_trait;
7use pingora::upstreams::peer::HttpPeer;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::RwLock;
13use tracing::{debug, info};
14
15use sentinel_common::{
16 errors::{SentinelError, SentinelResult},
17 types::{CircuitBreakerConfig, LoadBalancingAlgorithm, RetryPolicy},
18 CircuitBreaker, UpstreamId,
19};
20use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};
21
/// A single upstream endpoint, stored as separate host and port parts.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Host part of the endpoint (everything before the final `:` when the
    /// target is parsed by `from_address`).
    pub address: String,
    /// Port number of the endpoint.
    pub port: u16,
    /// Weight attached to this target; `from_address` defaults it to 100 and
    /// `from_config` overrides it with the configured value.
    pub weight: u32,
}
39
40impl UpstreamTarget {
41 pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
43 Self {
44 address: address.into(),
45 port,
46 weight,
47 }
48 }
49
50 pub fn from_address(addr: &str) -> Option<Self> {
52 let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
53 if parts.len() == 2 {
54 let port = parts[0].parse().ok()?;
55 let address = parts[1].to_string();
56 Some(Self {
57 address,
58 port,
59 weight: 100,
60 })
61 } else {
62 None
63 }
64 }
65
66 pub fn from_config(config: &sentinel_config::UpstreamTarget) -> Option<Self> {
68 Self::from_address(&config.address).map(|mut t| {
69 t.weight = config.weight;
70 t
71 })
72 }
73
74 pub fn full_address(&self) -> String {
76 format!("{}:{}", self.address, self.port)
77 }
78}
79
80pub mod adaptive;
86pub mod consistent_hash;
87pub mod p2c;
88
89pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
91pub use consistent_hash::{
92 ConsistentHashBalancer, ConsistentHashConfig,
93};
94pub use p2c::{P2cBalancer, P2cConfig};
95
/// Per-request information handed to a [`LoadBalancer`] when selecting a target.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Socket address of the client, when known. `IpHashBalancer` derives its
    /// hash from this value.
    pub client_ip: Option<std::net::SocketAddr>,
    /// Request headers as string pairs.
    pub headers: HashMap<String, String>,
    /// Request path.
    pub path: String,
    /// Request method.
    pub method: String,
}
104
/// Strategy interface for choosing an upstream target.
///
/// Implementations must be `Send + Sync`; `UpstreamPool` stores them behind
/// `Arc<dyn LoadBalancer>`.
#[async_trait]
pub trait LoadBalancer: Send + Sync {
    /// Picks a target for a request. `context` is optional; implementations
    /// in this file return `SentinelError::NoHealthyUpstream` when no target
    /// is eligible.
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection>;

    /// Records whether the target at `address` is currently healthy.
    async fn report_health(&self, address: &str, healthy: bool);

    /// Returns the addresses the balancer currently considers healthy.
    async fn healthy_targets(&self) -> Vec<String>;

    /// Hook called when a previously returned selection is no longer in use.
    /// The default implementation does nothing.
    async fn release(&self, _selection: &TargetSelection) {
        // Default: nothing to release.
    }

    /// Hook for reporting the outcome (and optional latency) of a request
    /// sent to a selection. The default implementation does nothing.
    async fn report_result(
        &self,
        _selection: &TargetSelection,
        _success: bool,
        _latency: Option<Duration>,
    ) {
        // Default: outcome is ignored.
    }
}
132
/// Result of a [`LoadBalancer::select`] call.
#[derive(Debug, Clone)]
pub struct TargetSelection {
    /// Chosen endpoint; the balancers in this file fill it with
    /// `UpstreamTarget::full_address()` (`host:port`).
    pub address: String,
    /// Weight of the chosen target.
    pub weight: u32,
    /// Free-form key/value data attached by the balancer; the balancers in
    /// this file leave it empty.
    pub metadata: HashMap<String, String>,
}
143
/// A named group of upstream targets together with the machinery used to
/// pick one of them for a request.
pub struct UpstreamPool {
    /// Identifier built from the configuration's `id`.
    id: UpstreamId,
    /// Targets that were successfully parsed from the configuration.
    targets: Vec<UpstreamTarget>,
    /// Selection strategy chosen from the configured algorithm.
    load_balancer: Arc<dyn LoadBalancer>,
    /// Present when the configuration has a health-check section. Nothing in
    /// the code visible here starts or polls it.
    health_checker: Option<Arc<UpstreamHealthChecker>>,
    /// Connection pool consulted before a new peer is created.
    connection_pool: Arc<ConnectionPool>,
    /// One circuit breaker per target, keyed by `full_address()`.
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
    /// Retry policy; `UpstreamPool::new` currently always stores `None`.
    retry_policy: Option<RetryPolicy>,
    /// Request/outcome counters for this pool.
    stats: Arc<PoolStats>,
}
163
/// State holder for health checking of upstream targets.
///
/// NOTE(review): only construction is visible in this part of the file; how
/// the fields are used is an assumption based on their names and types.
pub struct UpstreamHealthChecker {
    /// Health-check configuration passed to `new`.
    config: HealthCheckConfig,
    /// Per-target health records; presumably keyed by target address — confirm.
    health_status: Arc<RwLock<HashMap<String, TargetHealthStatus>>>,
    /// Handles of spawned Tokio tasks; presumably the periodic checks — confirm.
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
}
176
177impl UpstreamHealthChecker {
178 pub fn new(config: HealthCheckConfig) -> Self {
180 Self {
181 config,
182 health_status: Arc::new(RwLock::new(HashMap::new())),
183 check_handles: Arc::new(RwLock::new(Vec::new())),
184 }
185 }
186}
187
/// Health record for one target.
///
/// NOTE(review): no code visible here reads or writes these fields; the
/// descriptions below follow the field names and should be confirmed.
#[derive(Debug, Clone)]
struct TargetHealthStatus {
    /// Whether the target is currently regarded as healthy.
    healthy: bool,
    /// Presumably the number of successful checks in a row.
    consecutive_successes: u32,
    /// Presumably the number of failed checks in a row.
    consecutive_failures: u32,
    /// Presumably the time of the most recent check.
    last_check: Instant,
    /// Presumably the time of the most recent successful check, if any.
    last_success: Option<Instant>,
    /// Presumably the message of the most recent failure, if any.
    last_error: Option<String>,
}
204
/// Pool of upstream connections, grouped per address.
///
/// The limit and timeout fields are stored by `new` but are not read by any
/// code visible here.
pub struct ConnectionPool {
    /// Configured maximum number of connections.
    max_connections: usize,
    /// Configured maximum number of idle connections.
    max_idle: usize,
    /// Configured idle timeout.
    idle_timeout: Duration,
    /// Configured maximum connection lifetime, if any.
    max_lifetime: Option<Duration>,
    /// Pooled connections; emptied by `close_all`. Presumably keyed by target
    /// address — confirm.
    connections: Arc<RwLock<HashMap<String, Vec<PooledConnection>>>>,
    /// Counters describing pool activity.
    stats: Arc<ConnectionPoolStats>,
}
217
218impl ConnectionPool {
219 pub fn new(
221 max_connections: usize,
222 max_idle: usize,
223 idle_timeout: Duration,
224 max_lifetime: Option<Duration>,
225 ) -> Self {
226 Self {
227 max_connections,
228 max_idle,
229 idle_timeout,
230 max_lifetime,
231 connections: Arc::new(RwLock::new(HashMap::new())),
232 stats: Arc::new(ConnectionPoolStats::default()),
233 }
234 }
235
236 pub async fn acquire(&self, _address: &str) -> SentinelResult<Option<HttpPeer>> {
238 Ok(None)
241 }
242
243 pub async fn close_all(&self) {
245 let mut connections = self.connections.write().await;
246 connections.clear();
247 }
248}
249
/// One entry of the [`ConnectionPool`].
///
/// NOTE(review): no code visible here constructs or reads this type; the
/// field descriptions follow the field names and should be confirmed.
struct PooledConnection {
    /// The pooled peer.
    peer: HttpPeer,
    /// Presumably when the connection was created.
    created: Instant,
    /// Presumably when the connection was last handed out.
    last_used: Instant,
    /// Presumably whether the connection is currently handed out.
    in_use: bool,
}
261
/// Atomic counters for the connection pool; all start at zero via `Default`.
///
/// NOTE(review): nothing visible here updates these counters; the meanings
/// below follow the field names.
#[derive(Default)]
struct ConnectionPoolStats {
    /// Presumably connections created.
    created: AtomicU64,
    /// Presumably connections reused from the pool.
    reused: AtomicU64,
    /// Presumably connections closed.
    closed: AtomicU64,
    /// Presumably connections currently in use.
    active: AtomicU64,
    /// Presumably connections currently idle.
    idle: AtomicU64,
}
276
/// Atomic counters for an [`UpstreamPool`]; all start at zero via `Default`.
#[derive(Default)]
pub struct PoolStats {
    /// Incremented once per `UpstreamPool::select_peer` call.
    pub requests: AtomicU64,
    /// Incremented when `select_peer` hands out a newly created peer.
    pub successes: AtomicU64,
    /// Incremented when `select_peer` runs out of attempts and when
    /// `report_result` is called with `success == false`.
    pub failures: AtomicU64,
    /// Not updated by any code visible here.
    pub retries: AtomicU64,
    /// Not updated by any code visible here.
    pub circuit_breaker_trips: AtomicU64,
}
293
/// Balancer that cycles through the healthy targets in order.
struct RoundRobinBalancer {
    /// All configured targets, healthy or not.
    targets: Vec<UpstreamTarget>,
    /// Monotonic counter; its value modulo the number of healthy targets
    /// picks the next target.
    current: AtomicUsize,
    /// Health flag per `full_address()`; a missing entry is treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
300
301impl RoundRobinBalancer {
302 fn new(targets: Vec<UpstreamTarget>) -> Self {
303 let mut health_status = HashMap::new();
304 for target in &targets {
305 health_status.insert(target.full_address(), true);
306 }
307
308 Self {
309 targets,
310 current: AtomicUsize::new(0),
311 health_status: Arc::new(RwLock::new(health_status)),
312 }
313 }
314}
315
316#[async_trait]
317impl LoadBalancer for RoundRobinBalancer {
318 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
319 let health = self.health_status.read().await;
320 let healthy_targets: Vec<_> = self
321 .targets
322 .iter()
323 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
324 .collect();
325
326 if healthy_targets.is_empty() {
327 return Err(SentinelError::NoHealthyUpstream);
328 }
329
330 let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
331 let target = healthy_targets[index];
332
333 Ok(TargetSelection {
334 address: target.full_address(),
335 weight: target.weight,
336 metadata: HashMap::new(),
337 })
338 }
339
340 async fn report_health(&self, address: &str, healthy: bool) {
341 self.health_status
342 .write()
343 .await
344 .insert(address.to_string(), healthy);
345 }
346
347 async fn healthy_targets(&self) -> Vec<String> {
348 self.health_status
349 .read()
350 .await
351 .iter()
352 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
353 .collect()
354 }
355}
356
/// Balancer that prefers the healthy target with the smallest connection count.
struct LeastConnectionsBalancer {
    /// All configured targets, healthy or not.
    targets: Vec<UpstreamTarget>,
    /// Connection count per `full_address()`; a missing entry counts as 0.
    connections: Arc<RwLock<HashMap<String, usize>>>,
    /// Health flag per `full_address()`; a missing entry is treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
363
364impl LeastConnectionsBalancer {
365 fn new(targets: Vec<UpstreamTarget>) -> Self {
366 let mut health_status = HashMap::new();
367 let mut connections = HashMap::new();
368
369 for target in &targets {
370 let addr = target.full_address();
371 health_status.insert(addr.clone(), true);
372 connections.insert(addr, 0);
373 }
374
375 Self {
376 targets,
377 connections: Arc::new(RwLock::new(connections)),
378 health_status: Arc::new(RwLock::new(health_status)),
379 }
380 }
381}
382
383#[async_trait]
384impl LoadBalancer for LeastConnectionsBalancer {
385 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
386 let health = self.health_status.read().await;
387 let conns = self.connections.read().await;
388
389 let mut best_target = None;
390 let mut min_connections = usize::MAX;
391
392 for target in &self.targets {
393 let addr = target.full_address();
394 if !*health.get(&addr).unwrap_or(&true) {
395 continue;
396 }
397
398 let conn_count = *conns.get(&addr).unwrap_or(&0);
399 if conn_count < min_connections {
400 min_connections = conn_count;
401 best_target = Some(target);
402 }
403 }
404
405 best_target
406 .map(|target| TargetSelection {
407 address: target.full_address(),
408 weight: target.weight,
409 metadata: HashMap::new(),
410 })
411 .ok_or(SentinelError::NoHealthyUpstream)
412 }
413
414 async fn report_health(&self, address: &str, healthy: bool) {
415 self.health_status
416 .write()
417 .await
418 .insert(address.to_string(), healthy);
419 }
420
421 async fn healthy_targets(&self) -> Vec<String> {
422 self.health_status
423 .read()
424 .await
425 .iter()
426 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
427 .collect()
428 }
429}
430
/// Balancer that carries a per-target weight alongside the targets.
struct WeightedBalancer {
    /// All configured targets, healthy or not.
    targets: Vec<UpstreamTarget>,
    /// Weight for the target at the same index in `targets`; a missing entry
    /// is read as 1.
    weights: Vec<u32>,
    /// Monotonic counter that drives the rotation.
    current_index: AtomicUsize,
    /// Health flag per `full_address()`; a missing entry is treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
438
439#[async_trait]
440impl LoadBalancer for WeightedBalancer {
441 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
442 let health = self.health_status.read().await;
443 let healthy_indices: Vec<_> = self
444 .targets
445 .iter()
446 .enumerate()
447 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
448 .map(|(i, _)| i)
449 .collect();
450
451 if healthy_indices.is_empty() {
452 return Err(SentinelError::NoHealthyUpstream);
453 }
454
455 let idx = self.current_index.fetch_add(1, Ordering::Relaxed) % healthy_indices.len();
456 let target_idx = healthy_indices[idx];
457 let target = &self.targets[target_idx];
458
459 Ok(TargetSelection {
460 address: target.full_address(),
461 weight: self.weights.get(target_idx).copied().unwrap_or(1),
462 metadata: HashMap::new(),
463 })
464 }
465
466 async fn report_health(&self, address: &str, healthy: bool) {
467 self.health_status
468 .write()
469 .await
470 .insert(address.to_string(), healthy);
471 }
472
473 async fn healthy_targets(&self) -> Vec<String> {
474 self.health_status
475 .read()
476 .await
477 .iter()
478 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
479 .collect()
480 }
481}
482
/// Balancer that maps a hash of the client address onto the healthy targets.
struct IpHashBalancer {
    /// All configured targets, healthy or not.
    targets: Vec<UpstreamTarget>,
    /// Health flag per `full_address()`; a missing entry is treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
488
489#[async_trait]
490impl LoadBalancer for IpHashBalancer {
491 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
492 let health = self.health_status.read().await;
493 let healthy_targets: Vec<_> = self
494 .targets
495 .iter()
496 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
497 .collect();
498
499 if healthy_targets.is_empty() {
500 return Err(SentinelError::NoHealthyUpstream);
501 }
502
503 let hash = if let Some(ctx) = context {
505 if let Some(ip) = &ctx.client_ip {
506 use std::hash::{Hash, Hasher};
507 let mut hasher = std::collections::hash_map::DefaultHasher::new();
508 ip.hash(&mut hasher);
509 hasher.finish()
510 } else {
511 0
512 }
513 } else {
514 0
515 };
516
517 let idx = (hash as usize) % healthy_targets.len();
518 let target = healthy_targets[idx];
519
520 Ok(TargetSelection {
521 address: target.full_address(),
522 weight: target.weight,
523 metadata: HashMap::new(),
524 })
525 }
526
527 async fn report_health(&self, address: &str, healthy: bool) {
528 self.health_status
529 .write()
530 .await
531 .insert(address.to_string(), healthy);
532 }
533
534 async fn healthy_targets(&self) -> Vec<String> {
535 self.health_status
536 .read()
537 .await
538 .iter()
539 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
540 .collect()
541 }
542}
543
544impl UpstreamPool {
545 pub async fn new(config: UpstreamConfig) -> SentinelResult<Self> {
547 let id = UpstreamId::new(&config.id);
548
549 let targets: Vec<UpstreamTarget> = config
551 .targets
552 .iter()
553 .filter_map(|t| UpstreamTarget::from_config(t))
554 .collect();
555
556 if targets.is_empty() {
557 return Err(SentinelError::Config {
558 message: "No valid upstream targets".to_string(),
559 source: None,
560 });
561 }
562
563 let load_balancer = Self::create_load_balancer(&config.load_balancing, &targets)?;
565
566 let health_checker = config
568 .health_check
569 .as_ref()
570 .map(|hc_config| Arc::new(UpstreamHealthChecker::new(hc_config.clone())));
571
572 let connection_pool = Arc::new(ConnectionPool::new(
574 config.connection_pool.max_connections,
575 config.connection_pool.max_idle,
576 Duration::from_secs(config.connection_pool.idle_timeout_secs),
577 config
578 .connection_pool
579 .max_lifetime_secs
580 .map(Duration::from_secs),
581 ));
582
583 let mut circuit_breakers = HashMap::new();
585 for target in &targets {
586 circuit_breakers.insert(
587 target.full_address(),
588 CircuitBreaker::new(CircuitBreakerConfig::default()),
589 );
590 }
591
592 let pool = Self {
593 id,
594 targets,
595 load_balancer,
596 health_checker,
597 connection_pool,
598 circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
599 retry_policy: None,
600 stats: Arc::new(PoolStats::default()),
601 };
602
603 Ok(pool)
604 }
605
606 fn create_load_balancer(
608 algorithm: &LoadBalancingAlgorithm,
609 targets: &[UpstreamTarget],
610 ) -> SentinelResult<Arc<dyn LoadBalancer>> {
611 let balancer: Arc<dyn LoadBalancer> = match algorithm {
612 LoadBalancingAlgorithm::RoundRobin => {
613 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
614 }
615 LoadBalancingAlgorithm::LeastConnections => {
616 Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
617 }
618 LoadBalancingAlgorithm::Weighted => {
619 let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
620 Arc::new(WeightedBalancer {
621 targets: targets.to_vec(),
622 weights,
623 current_index: AtomicUsize::new(0),
624 health_status: Arc::new(RwLock::new(HashMap::new())),
625 })
626 }
627 LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
628 targets: targets.to_vec(),
629 health_status: Arc::new(RwLock::new(HashMap::new())),
630 }),
631 LoadBalancingAlgorithm::Random => {
632 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
633 }
634 LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
635 targets.to_vec(),
636 ConsistentHashConfig::default(),
637 )),
638 LoadBalancingAlgorithm::PowerOfTwoChoices => {
639 Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
640 }
641 LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
642 targets.to_vec(),
643 AdaptiveConfig::default(),
644 )),
645 };
646 Ok(balancer)
647 }
648
649 pub async fn select_peer(&self, context: Option<&RequestContext>) -> SentinelResult<HttpPeer> {
651 self.stats.requests.fetch_add(1, Ordering::Relaxed);
652
653 let mut attempts = 0;
654 let max_attempts = self.targets.len() * 2;
655
656 while attempts < max_attempts {
657 attempts += 1;
658
659 let selection = self.load_balancer.select(context).await?;
660
661 let breakers = self.circuit_breakers.read().await;
663 if let Some(breaker) = breakers.get(&selection.address) {
664 if !breaker.is_closed().await {
665 debug!(
666 target = %selection.address,
667 "Circuit breaker is open, skipping target"
668 );
669 continue;
670 }
671 }
672
673 if let Some(peer) = self.connection_pool.acquire(&selection.address).await? {
675 debug!(target = %selection.address, "Reusing pooled connection");
676 return Ok(peer);
677 }
678
679 debug!(target = %selection.address, "Creating new connection");
681 let peer = self.create_peer(&selection)?;
682
683 self.stats.successes.fetch_add(1, Ordering::Relaxed);
684 return Ok(peer);
685 }
686
687 self.stats.failures.fetch_add(1, Ordering::Relaxed);
688 Err(SentinelError::upstream(
689 &self.id.to_string(),
690 "Failed to select upstream after max attempts",
691 ))
692 }
693
694 fn create_peer(&self, selection: &TargetSelection) -> SentinelResult<HttpPeer> {
696 let peer = HttpPeer::new(
697 &selection.address,
698 false,
699 String::new(),
700 );
701 Ok(peer)
702 }
703
704 pub async fn report_result(&self, target: &str, success: bool) {
706 if success {
707 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
708 breaker.record_success().await;
709 }
710 self.load_balancer.report_health(target, true).await;
711 } else {
712 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
713 breaker.record_failure().await;
714 }
715 self.load_balancer.report_health(target, false).await;
716 self.stats.failures.fetch_add(1, Ordering::Relaxed);
717 }
718 }
719
720 pub fn stats(&self) -> &PoolStats {
722 &self.stats
723 }
724
725 pub async fn shutdown(&self) {
727 info!("Shutting down upstream pool: {}", self.id);
728 self.connection_pool.close_all().await;
729 }
730}