1use async_trait::async_trait;
7use pingora::upstreams::peer::HttpPeer;
8use rand::seq::IndexedRandom;
9use std::collections::HashMap;
10use std::net::ToSocketAddrs;
11use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info, trace, warn};
16
17use sentinel_common::{
18 errors::{SentinelError, SentinelResult},
19 types::{CircuitBreakerConfig, LoadBalancingAlgorithm},
20 CircuitBreaker, UpstreamId,
21};
22use sentinel_config::UpstreamConfig;
23
/// A single backend target: host/IP plus port and a relative
/// load-balancing weight.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Hostname or IP address, without the port.
    pub address: String,
    /// TCP port of the target.
    pub port: u16,
    /// Relative weight for weight-aware balancers (parse default: 100).
    pub weight: u32,
}
41
42impl UpstreamTarget {
43 pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
45 Self {
46 address: address.into(),
47 port,
48 weight,
49 }
50 }
51
52 pub fn from_address(addr: &str) -> Option<Self> {
54 let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
55 if parts.len() == 2 {
56 let port = parts[0].parse().ok()?;
57 let address = parts[1].to_string();
58 Some(Self {
59 address,
60 port,
61 weight: 100,
62 })
63 } else {
64 None
65 }
66 }
67
68 pub fn from_config(config: &sentinel_config::UpstreamTarget) -> Option<Self> {
70 Self::from_address(&config.address).map(|mut t| {
71 t.weight = config.weight;
72 t
73 })
74 }
75
76 pub fn full_address(&self) -> String {
78 format!("{}:{}", self.address, self.port)
79 }
80}
81
82pub mod adaptive;
88pub mod consistent_hash;
89pub mod health;
90pub mod inference_health;
91pub mod least_tokens;
92pub mod locality;
93pub mod maglev;
94pub mod p2c;
95pub mod peak_ewma;
96pub mod sticky_session;
97pub mod subset;
98pub mod weighted_least_conn;
99
100pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
102pub use consistent_hash::{ConsistentHashBalancer, ConsistentHashConfig};
103pub use health::{ActiveHealthChecker, HealthCheckRunner};
104pub use inference_health::InferenceHealthCheck;
105pub use least_tokens::{LeastTokensQueuedBalancer, LeastTokensQueuedConfig, LeastTokensQueuedTargetStats};
106pub use locality::{LocalityAwareBalancer, LocalityAwareConfig};
107pub use maglev::{MaglevBalancer, MaglevConfig};
108pub use p2c::{P2cBalancer, P2cConfig};
109pub use peak_ewma::{PeakEwmaBalancer, PeakEwmaConfig};
110pub use sticky_session::{StickySessionBalancer, StickySessionRuntimeConfig};
111pub use subset::{SubsetBalancer, SubsetConfig};
112pub use weighted_least_conn::{WeightedLeastConnBalancer, WeightedLeastConnConfig};
113
/// Per-request information passed to balancers that need it
/// (e.g. IP hashing or sticky sessions).
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Client socket address, when known.
    pub client_ip: Option<std::net::SocketAddr>,
    /// Request headers (name -> value).
    pub headers: HashMap<String, String>,
    /// Request path.
    pub path: String,
    /// HTTP method.
    pub method: String,
}
122
/// Strategy interface implemented by every load-balancing algorithm.
#[async_trait]
pub trait LoadBalancer: Send + Sync {
    /// Picks a target for the next request. `context` is optional because
    /// not every algorithm needs request data.
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection>;

    /// Records a health observation for `address` ("host:port").
    async fn report_health(&self, address: &str, healthy: bool);

    /// Lists the addresses currently considered healthy.
    async fn healthy_targets(&self) -> Vec<String>;

    /// Releases a previously selected target (e.g. decrements an in-flight
    /// counter). Default is a no-op for stateless algorithms.
    async fn release(&self, _selection: &TargetSelection) {
    }

    /// Reports the outcome of a request for a full selection. Default is a
    /// no-op; latency-aware balancers override this.
    async fn report_result(
        &self,
        _selection: &TargetSelection,
        _success: bool,
        _latency: Option<Duration>,
    ) {
    }

    /// Reports an outcome by address only. The default implementation just
    /// forwards success/failure to `report_health`, discarding the latency.
    async fn report_result_with_latency(
        &self,
        address: &str,
        success: bool,
        _latency: Option<Duration>,
    ) {
        self.report_health(address, success).await;
    }
}
166
/// Result of a load-balancer selection.
#[derive(Debug, Clone)]
pub struct TargetSelection {
    /// Chosen target as "host:port".
    pub address: String,
    /// Configured weight of the chosen target.
    pub weight: u32,
    /// Algorithm-specific metadata propagated to the caller
    /// (e.g. sticky-session cookie values).
    pub metadata: HashMap<String, String>,
}
177
/// A pool of upstream targets sharing one load-balancing algorithm,
/// per-target circuit breakers, and connection/TLS settings.
pub struct UpstreamPool {
    /// Identifier of this upstream in the configuration.
    id: UpstreamId,
    /// All configured targets (healthy or not).
    targets: Vec<UpstreamTarget>,
    /// Selection strategy for this pool.
    load_balancer: Arc<dyn LoadBalancer>,
    /// Connection/timeout settings applied to each created peer.
    pool_config: ConnectionPoolConfig,
    /// HTTP version negotiation options (ALPN, H2 ping).
    http_version: HttpVersionOptions,
    /// Whether upstream connections use TLS.
    tls_enabled: bool,
    /// Optional SNI override for TLS handshakes.
    tls_sni: Option<String>,
    /// Full upstream TLS configuration (mTLS, verification flags).
    tls_config: Option<sentinel_config::UpstreamTlsConfig>,
    /// One circuit breaker per target, keyed by "host:port".
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
    /// Aggregate counters for this pool.
    stats: Arc<PoolStats>,
}
201
/// Connection-pool and timeout settings derived from configuration.
pub struct ConnectionPoolConfig {
    /// Maximum concurrent connections to the upstream.
    pub max_connections: usize,
    /// Maximum idle connections kept for reuse.
    pub max_idle: usize,
    /// How long an idle connection may be kept before closing.
    pub idle_timeout: Duration,
    /// Optional cap on a connection's total lifetime.
    pub max_lifetime: Option<Duration>,
    /// TCP/TLS connect timeout.
    pub connection_timeout: Duration,
    /// Per-read timeout.
    pub read_timeout: Duration,
    /// Per-write timeout.
    pub write_timeout: Duration,
}
226
/// HTTP version negotiation options for upstream connections.
pub struct HttpVersionOptions {
    /// Minimum acceptable HTTP major version.
    pub min_version: u8,
    /// Maximum acceptable HTTP major version.
    pub max_version: u8,
    /// HTTP/2 keepalive ping interval; `Duration::ZERO` disables pings.
    pub h2_ping_interval: Duration,
    /// Maximum concurrent HTTP/2 streams per connection.
    pub max_h2_streams: usize,
}
238
239impl ConnectionPoolConfig {
240 pub fn from_config(
242 pool_config: &sentinel_config::ConnectionPoolConfig,
243 timeouts: &sentinel_config::UpstreamTimeouts,
244 ) -> Self {
245 Self {
246 max_connections: pool_config.max_connections,
247 max_idle: pool_config.max_idle,
248 idle_timeout: Duration::from_secs(pool_config.idle_timeout_secs),
249 max_lifetime: pool_config.max_lifetime_secs.map(Duration::from_secs),
250 connection_timeout: Duration::from_secs(timeouts.connect_secs),
251 read_timeout: Duration::from_secs(timeouts.read_secs),
252 write_timeout: Duration::from_secs(timeouts.write_secs),
253 }
254 }
255}
256
/// Aggregate request counters for a pool; all fields are lock-free atomics.
#[derive(Default)]
pub struct PoolStats {
    /// Total selection attempts started.
    pub requests: AtomicU64,
    /// Selections that produced a usable peer.
    pub successes: AtomicU64,
    /// Selections that exhausted every attempt.
    pub failures: AtomicU64,
    /// Retried requests.
    pub retries: AtomicU64,
    /// Times a target was skipped because its circuit breaker was open.
    pub circuit_breaker_trips: AtomicU64,
}
273
/// Destination for traffic shadowing (mirrored requests).
#[derive(Debug, Clone)]
pub struct ShadowTarget {
    pub scheme: String,
    pub host: String,
    pub port: u16,
    pub sni: Option<String>,
}

impl ShadowTarget {
    /// Renders the full URL for `path`, omitting the port when it is the
    /// scheme's default (80 for http, 443 for https).
    pub fn build_url(&self, path: &str) -> String {
        let is_default_port = (self.scheme == "http" && self.port == 80)
            || (self.scheme == "https" && self.port == 443);
        if is_default_port {
            format!("{}://{}{}", self.scheme, self.host, path)
        } else {
            format!("{}://{}:{}{}", self.scheme, self.host, self.port, path)
        }
    }
}
297
/// Plain-data snapshot of a pool's connection settings, mirroring
/// `ConnectionPoolConfig` with raw seconds instead of `Duration`s.
#[derive(Debug, Clone)]
pub struct PoolConfigSnapshot {
    pub max_connections: usize,
    pub max_idle: usize,
    pub idle_timeout_secs: u64,
    pub max_lifetime_secs: Option<u64>,
    pub connection_timeout_secs: u64,
    pub read_timeout_secs: u64,
    pub write_timeout_secs: u64,
}
316
/// Round-robin balancer cycling through the healthy subset of targets.
struct RoundRobinBalancer {
    targets: Vec<UpstreamTarget>,
    /// Monotonic counter; `counter % healthy_count` picks the next target.
    current: AtomicUsize,
    /// Health map keyed by "host:port"; missing entries default to healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
323
324impl RoundRobinBalancer {
325 fn new(targets: Vec<UpstreamTarget>) -> Self {
326 let mut health_status = HashMap::new();
327 for target in &targets {
328 health_status.insert(target.full_address(), true);
329 }
330
331 Self {
332 targets,
333 current: AtomicUsize::new(0),
334 health_status: Arc::new(RwLock::new(health_status)),
335 }
336 }
337}
338
339#[async_trait]
340impl LoadBalancer for RoundRobinBalancer {
341 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
342 trace!(
343 total_targets = self.targets.len(),
344 algorithm = "round_robin",
345 "Selecting upstream target"
346 );
347
348 let health = self.health_status.read().await;
349 let healthy_targets: Vec<_> = self
350 .targets
351 .iter()
352 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
353 .collect();
354
355 if healthy_targets.is_empty() {
356 warn!(
357 total_targets = self.targets.len(),
358 algorithm = "round_robin",
359 "No healthy upstream targets available"
360 );
361 return Err(SentinelError::NoHealthyUpstream);
362 }
363
364 let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
365 let target = healthy_targets[index];
366
367 trace!(
368 selected_target = %target.full_address(),
369 healthy_count = healthy_targets.len(),
370 index = index,
371 algorithm = "round_robin",
372 "Selected target via round robin"
373 );
374
375 Ok(TargetSelection {
376 address: target.full_address(),
377 weight: target.weight,
378 metadata: HashMap::new(),
379 })
380 }
381
382 async fn report_health(&self, address: &str, healthy: bool) {
383 trace!(
384 target = %address,
385 healthy = healthy,
386 algorithm = "round_robin",
387 "Updating target health status"
388 );
389 self.health_status
390 .write()
391 .await
392 .insert(address.to_string(), healthy);
393 }
394
395 async fn healthy_targets(&self) -> Vec<String> {
396 self.health_status
397 .read()
398 .await
399 .iter()
400 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
401 .collect()
402 }
403}
404
/// Balancer that picks a uniformly random healthy target.
struct RandomBalancer {
    targets: Vec<UpstreamTarget>,
    /// Health map keyed by "host:port"; missing entries default to healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
410
411impl RandomBalancer {
412 fn new(targets: Vec<UpstreamTarget>) -> Self {
413 let mut health_status = HashMap::new();
414 for target in &targets {
415 health_status.insert(target.full_address(), true);
416 }
417
418 Self {
419 targets,
420 health_status: Arc::new(RwLock::new(health_status)),
421 }
422 }
423}
424
425#[async_trait]
426impl LoadBalancer for RandomBalancer {
427 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
428 use rand::seq::SliceRandom;
429
430 trace!(
431 total_targets = self.targets.len(),
432 algorithm = "random",
433 "Selecting upstream target"
434 );
435
436 let health = self.health_status.read().await;
437 let healthy_targets: Vec<_> = self
438 .targets
439 .iter()
440 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
441 .collect();
442
443 if healthy_targets.is_empty() {
444 warn!(
445 total_targets = self.targets.len(),
446 algorithm = "random",
447 "No healthy upstream targets available"
448 );
449 return Err(SentinelError::NoHealthyUpstream);
450 }
451
452 let mut rng = rand::rng();
453 let target = healthy_targets
454 .choose(&mut rng)
455 .ok_or(SentinelError::NoHealthyUpstream)?;
456
457 trace!(
458 selected_target = %target.full_address(),
459 healthy_count = healthy_targets.len(),
460 algorithm = "random",
461 "Selected target via random selection"
462 );
463
464 Ok(TargetSelection {
465 address: target.full_address(),
466 weight: target.weight,
467 metadata: HashMap::new(),
468 })
469 }
470
471 async fn report_health(&self, address: &str, healthy: bool) {
472 trace!(
473 target = %address,
474 healthy = healthy,
475 algorithm = "random",
476 "Updating target health status"
477 );
478 self.health_status
479 .write()
480 .await
481 .insert(address.to_string(), healthy);
482 }
483
484 async fn healthy_targets(&self) -> Vec<String> {
485 self.health_status
486 .read()
487 .await
488 .iter()
489 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
490 .collect()
491 }
492}
493
/// Balancer that prefers the healthy target with the fewest tracked
/// connections.
struct LeastConnectionsBalancer {
    targets: Vec<UpstreamTarget>,
    /// In-flight connection counts per "host:port".
    /// NOTE(review): nothing in this file ever increments or decrements
    /// these counts (`release`/`report_result` are not overridden), so
    /// selection currently behaves like "first healthy target" — confirm
    /// whether counts are updated elsewhere.
    connections: Arc<RwLock<HashMap<String, usize>>>,
    /// Health map keyed by "host:port"; missing entries default to healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
500
501impl LeastConnectionsBalancer {
502 fn new(targets: Vec<UpstreamTarget>) -> Self {
503 let mut health_status = HashMap::new();
504 let mut connections = HashMap::new();
505
506 for target in &targets {
507 let addr = target.full_address();
508 health_status.insert(addr.clone(), true);
509 connections.insert(addr, 0);
510 }
511
512 Self {
513 targets,
514 connections: Arc::new(RwLock::new(connections)),
515 health_status: Arc::new(RwLock::new(health_status)),
516 }
517 }
518}
519
#[async_trait]
impl LoadBalancer for LeastConnectionsBalancer {
    /// Scans all targets and returns the healthy one with the smallest
    /// tracked connection count; the earliest target wins on ties.
    ///
    /// NOTE(review): no code visible in this file ever updates
    /// `connections`, so unless counts are maintained elsewhere this always
    /// returns the first healthy target — confirm against callers.
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "least_connections",
            "Selecting upstream target"
        );

        let health = self.health_status.read().await;
        let conns = self.connections.read().await;

        // Linear scan for the minimum; ties resolve to the earliest target.
        let mut best_target = None;
        let mut min_connections = usize::MAX;

        for target in &self.targets {
            let addr = target.full_address();
            // Targets missing from the health map default to healthy.
            if !*health.get(&addr).unwrap_or(&true) {
                trace!(
                    target = %addr,
                    algorithm = "least_connections",
                    "Skipping unhealthy target"
                );
                continue;
            }

            // Targets missing from the counts map count as zero connections.
            let conn_count = *conns.get(&addr).unwrap_or(&0);
            trace!(
                target = %addr,
                connections = conn_count,
                "Evaluating target connection count"
            );
            if conn_count < min_connections {
                min_connections = conn_count;
                best_target = Some(target);
            }
        }

        match best_target {
            Some(target) => {
                trace!(
                    selected_target = %target.full_address(),
                    connections = min_connections,
                    algorithm = "least_connections",
                    "Selected target with fewest connections"
                );
                Ok(TargetSelection {
                    address: target.full_address(),
                    weight: target.weight,
                    metadata: HashMap::new(),
                })
            }
            None => {
                warn!(
                    total_targets = self.targets.len(),
                    algorithm = "least_connections",
                    "No healthy upstream targets available"
                );
                Err(SentinelError::NoHealthyUpstream)
            }
        }
    }

    /// Overwrites the health flag for `address`.
    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "least_connections",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    /// Returns every address whose health flag is currently `true`.
    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
605
/// Weight-aware balancer; weights are stored parallel to `targets`
/// (same index order).
struct WeightedBalancer {
    targets: Vec<UpstreamTarget>,
    /// Per-target weights, index-aligned with `targets`.
    weights: Vec<u32>,
    /// Monotonic selection counter.
    current_index: AtomicUsize,
    /// Health map keyed by "host:port"; missing entries default to healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
613
614#[async_trait]
615impl LoadBalancer for WeightedBalancer {
616 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
617 trace!(
618 total_targets = self.targets.len(),
619 algorithm = "weighted",
620 "Selecting upstream target"
621 );
622
623 let health = self.health_status.read().await;
624 let healthy_indices: Vec<_> = self
625 .targets
626 .iter()
627 .enumerate()
628 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
629 .map(|(i, _)| i)
630 .collect();
631
632 if healthy_indices.is_empty() {
633 warn!(
634 total_targets = self.targets.len(),
635 algorithm = "weighted",
636 "No healthy upstream targets available"
637 );
638 return Err(SentinelError::NoHealthyUpstream);
639 }
640
641 let idx = self.current_index.fetch_add(1, Ordering::Relaxed) % healthy_indices.len();
642 let target_idx = healthy_indices[idx];
643 let target = &self.targets[target_idx];
644 let weight = self.weights.get(target_idx).copied().unwrap_or(1);
645
646 trace!(
647 selected_target = %target.full_address(),
648 weight = weight,
649 healthy_count = healthy_indices.len(),
650 algorithm = "weighted",
651 "Selected target via weighted round robin"
652 );
653
654 Ok(TargetSelection {
655 address: target.full_address(),
656 weight,
657 metadata: HashMap::new(),
658 })
659 }
660
661 async fn report_health(&self, address: &str, healthy: bool) {
662 trace!(
663 target = %address,
664 healthy = healthy,
665 algorithm = "weighted",
666 "Updating target health status"
667 );
668 self.health_status
669 .write()
670 .await
671 .insert(address.to_string(), healthy);
672 }
673
674 async fn healthy_targets(&self) -> Vec<String> {
675 self.health_status
676 .read()
677 .await
678 .iter()
679 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
680 .collect()
681 }
682}
683
/// Balancer that hashes the client IP so a given client consistently lands
/// on the same target while the healthy set is stable.
struct IpHashBalancer {
    targets: Vec<UpstreamTarget>,
    /// Health map keyed by "host:port"; missing entries default to healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
689
#[async_trait]
impl LoadBalancer for IpHashBalancer {
    /// Hashes the client IP onto the healthy target list.
    ///
    /// Notes:
    /// - Requests without a context or client IP hash to 0, so they all
    ///   land on the first healthy target.
    /// - Affinity uses plain modulo over the healthy list, so it reshuffles
    ///   whenever the healthy set changes size (unlike consistent hashing).
    /// - `DefaultHasher` is stable within a process but not guaranteed
    ///   stable across restarts or versions.
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "ip_hash",
            "Selecting upstream target"
        );

        let health = self.health_status.read().await;
        let healthy_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .collect();

        if healthy_targets.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "ip_hash",
                "No healthy upstream targets available"
            );
            return Err(SentinelError::NoHealthyUpstream);
        }

        // Hash the full socket address (IP and port) when available;
        // otherwise fall back to hash 0.
        let (hash, client_ip_str) = if let Some(ctx) = context {
            if let Some(ip) = &ctx.client_ip {
                use std::hash::{Hash, Hasher};
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                ip.hash(&mut hasher);
                (hasher.finish(), Some(ip.to_string()))
            } else {
                (0, None)
            }
        } else {
            (0, None)
        };

        let idx = (hash as usize) % healthy_targets.len();
        let target = healthy_targets[idx];

        trace!(
            selected_target = %target.full_address(),
            client_ip = client_ip_str.as_deref().unwrap_or("unknown"),
            hash = hash,
            index = idx,
            healthy_count = healthy_targets.len(),
            algorithm = "ip_hash",
            "Selected target via IP hash"
        );

        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    /// Overwrites the health flag for `address`.
    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "ip_hash",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    /// Returns every address whose health flag is currently `true`.
    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
771
772impl UpstreamPool {
    /// Builds an upstream pool from configuration: parses targets,
    /// constructs the load balancer, connection/TLS options, and one
    /// circuit breaker per target.
    ///
    /// # Errors
    /// Returns a `Config` error when no target address parses, and
    /// propagates load-balancer construction errors (e.g. the Sticky
    /// algorithm without a sticky_session section).
    pub async fn new(config: UpstreamConfig) -> SentinelResult<Self> {
        let id = UpstreamId::new(&config.id);

        info!(
            upstream_id = %config.id,
            target_count = config.targets.len(),
            algorithm = ?config.load_balancing,
            "Creating upstream pool"
        );

        // Unparseable target entries are silently dropped here; the
        // emptiness check below turns "nothing parsed" into a hard error.
        let targets: Vec<UpstreamTarget> = config
            .targets
            .iter()
            .filter_map(UpstreamTarget::from_config)
            .collect();

        if targets.is_empty() {
            error!(
                upstream_id = %config.id,
                "No valid upstream targets configured"
            );
            return Err(SentinelError::Config {
                message: "No valid upstream targets".to_string(),
                source: None,
            });
        }

        for target in &targets {
            debug!(
                upstream_id = %config.id,
                target = %target.full_address(),
                weight = target.weight,
                "Registered upstream target"
            );
        }

        debug!(
            upstream_id = %config.id,
            algorithm = ?config.load_balancing,
            "Creating load balancer"
        );
        let load_balancer =
            Self::create_load_balancer(&config.load_balancing, &targets, &config)?;

        debug!(
            upstream_id = %config.id,
            max_connections = config.connection_pool.max_connections,
            max_idle = config.connection_pool.max_idle,
            idle_timeout_secs = config.connection_pool.idle_timeout_secs,
            connect_timeout_secs = config.timeouts.connect_secs,
            read_timeout_secs = config.timeouts.read_secs,
            write_timeout_secs = config.timeouts.write_secs,
            "Creating connection pool configuration"
        );
        let pool_config =
            ConnectionPoolConfig::from_config(&config.connection_pool, &config.timeouts);

        // A configured ping interval of 0 means "disabled"; it is stored as
        // Duration::ZERO and checked with is_zero() when creating peers.
        let http_version = HttpVersionOptions {
            min_version: config.http_version.min_version,
            max_version: config.http_version.max_version,
            h2_ping_interval: if config.http_version.h2_ping_interval_secs > 0 {
                Duration::from_secs(config.http_version.h2_ping_interval_secs)
            } else {
                Duration::ZERO
            },
            max_h2_streams: config.http_version.max_h2_streams,
        };

        let tls_enabled = config.tls.is_some();
        let tls_sni = config.tls.as_ref().and_then(|t| t.sni.clone());
        let tls_config = config.tls.clone();

        if let Some(ref tls) = tls_config {
            if tls.client_cert.is_some() {
                info!(
                    upstream_id = %config.id,
                    "mTLS enabled for upstream (client certificate configured)"
                );
            }
        }

        if http_version.max_version >= 2 && tls_enabled {
            info!(
                upstream_id = %config.id,
                "HTTP/2 enabled for upstream (via ALPN)"
            );
        }

        // One breaker per target so a single failing backend can be
        // isolated without affecting the rest of the pool.
        let mut circuit_breakers = HashMap::new();
        for target in &targets {
            trace!(
                upstream_id = %config.id,
                target = %target.full_address(),
                "Initializing circuit breaker for target"
            );
            circuit_breakers.insert(
                target.full_address(),
                CircuitBreaker::new(CircuitBreakerConfig::default()),
            );
        }

        let pool = Self {
            id: id.clone(),
            targets,
            load_balancer,
            pool_config,
            http_version,
            tls_enabled,
            tls_sni,
            tls_config,
            circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
            stats: Arc::new(PoolStats::default()),
        };

        info!(
            upstream_id = %id,
            target_count = pool.targets.len(),
            "Upstream pool created successfully"
        );

        Ok(pool)
    }
903
904 fn create_load_balancer(
906 algorithm: &LoadBalancingAlgorithm,
907 targets: &[UpstreamTarget],
908 config: &UpstreamConfig,
909 ) -> SentinelResult<Arc<dyn LoadBalancer>> {
910 let balancer: Arc<dyn LoadBalancer> = match algorithm {
911 LoadBalancingAlgorithm::RoundRobin => {
912 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
913 }
914 LoadBalancingAlgorithm::LeastConnections => {
915 Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
916 }
917 LoadBalancingAlgorithm::Weighted => {
918 let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
919 Arc::new(WeightedBalancer {
920 targets: targets.to_vec(),
921 weights,
922 current_index: AtomicUsize::new(0),
923 health_status: Arc::new(RwLock::new(HashMap::new())),
924 })
925 }
926 LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
927 targets: targets.to_vec(),
928 health_status: Arc::new(RwLock::new(HashMap::new())),
929 }),
930 LoadBalancingAlgorithm::Random => {
931 Arc::new(RandomBalancer::new(targets.to_vec()))
932 }
933 LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
934 targets.to_vec(),
935 ConsistentHashConfig::default(),
936 )),
937 LoadBalancingAlgorithm::PowerOfTwoChoices => {
938 Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
939 }
940 LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
941 targets.to_vec(),
942 AdaptiveConfig::default(),
943 )),
944 LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
945 targets.to_vec(),
946 LeastTokensQueuedConfig::default(),
947 )),
948 LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
949 targets.to_vec(),
950 MaglevConfig::default(),
951 )),
952 LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
953 targets.to_vec(),
954 LocalityAwareConfig::default(),
955 )),
956 LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
957 targets.to_vec(),
958 PeakEwmaConfig::default(),
959 )),
960 LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
961 targets.to_vec(),
962 SubsetConfig::default(),
963 )),
964 LoadBalancingAlgorithm::WeightedLeastConnections => {
965 Arc::new(WeightedLeastConnBalancer::new(
966 targets.to_vec(),
967 WeightedLeastConnConfig::default(),
968 ))
969 }
970 LoadBalancingAlgorithm::Sticky => {
971 let sticky_config = config.sticky_session.as_ref().ok_or_else(|| {
973 SentinelError::Config {
974 message: format!(
975 "Upstream '{}' uses Sticky algorithm but no sticky_session config provided",
976 config.id
977 ),
978 source: None,
979 }
980 })?;
981
982 let runtime_config = StickySessionRuntimeConfig::from_config(sticky_config);
984
985 let fallback = Self::create_load_balancer_inner(&sticky_config.fallback, targets)?;
987
988 info!(
989 upstream_id = %config.id,
990 cookie_name = %runtime_config.cookie_name,
991 cookie_ttl_secs = runtime_config.cookie_ttl_secs,
992 fallback_algorithm = ?sticky_config.fallback,
993 "Creating sticky session balancer"
994 );
995
996 Arc::new(StickySessionBalancer::new(
997 targets.to_vec(),
998 runtime_config,
999 fallback,
1000 ))
1001 }
1002 };
1003 Ok(balancer)
1004 }
1005
    /// Constructs a balancer for any non-Sticky algorithm, using each
    /// algorithm's default configuration. Also used as the factory for
    /// sticky-session fallback balancers.
    ///
    /// # Errors
    /// Returns a `Config` error for `Sticky`, which cannot serve as its own
    /// fallback (that would require sticky config and could recurse).
    fn create_load_balancer_inner(
        algorithm: &LoadBalancingAlgorithm,
        targets: &[UpstreamTarget],
    ) -> SentinelResult<Arc<dyn LoadBalancer>> {
        let balancer: Arc<dyn LoadBalancer> = match algorithm {
            LoadBalancingAlgorithm::RoundRobin => {
                Arc::new(RoundRobinBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::LeastConnections => {
                Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::Weighted => {
                let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
                Arc::new(WeightedBalancer {
                    targets: targets.to_vec(),
                    weights,
                    current_index: AtomicUsize::new(0),
                    // Starts empty; targets absent from the map are treated
                    // as healthy by the balancer.
                    health_status: Arc::new(RwLock::new(HashMap::new())),
                })
            }
            LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
                targets: targets.to_vec(),
                health_status: Arc::new(RwLock::new(HashMap::new())),
            }),
            LoadBalancingAlgorithm::Random => {
                Arc::new(RandomBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
                targets.to_vec(),
                ConsistentHashConfig::default(),
            )),
            LoadBalancingAlgorithm::PowerOfTwoChoices => {
                Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
            }
            LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
                targets.to_vec(),
                AdaptiveConfig::default(),
            )),
            LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
                targets.to_vec(),
                LeastTokensQueuedConfig::default(),
            )),
            LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
                targets.to_vec(),
                MaglevConfig::default(),
            )),
            LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
                targets.to_vec(),
                LocalityAwareConfig::default(),
            )),
            LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
                targets.to_vec(),
                PeakEwmaConfig::default(),
            )),
            LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
                targets.to_vec(),
                SubsetConfig::default(),
            )),
            LoadBalancingAlgorithm::WeightedLeastConnections => {
                Arc::new(WeightedLeastConnBalancer::new(
                    targets.to_vec(),
                    WeightedLeastConnConfig::default(),
                ))
            }
            LoadBalancingAlgorithm::Sticky => {
                return Err(SentinelError::Config {
                    message: "Sticky algorithm cannot be used as fallback for sticky sessions"
                        .to_string(),
                    source: None,
                });
            }
        };
        Ok(balancer)
    }
1082
    /// Selects a target (skipping any with an open circuit breaker) and
    /// builds an `HttpPeer` plus the balancer's selection metadata.
    ///
    /// Makes up to `targets.len() * 2` attempts before giving up.
    ///
    /// # Errors
    /// Returns an upstream error when every attempt fails (no healthy
    /// target, all breakers open, or peer creation failing).
    pub async fn select_peer_with_metadata(
        &self,
        context: Option<&RequestContext>,
    ) -> SentinelResult<(HttpPeer, HashMap<String, String>)> {
        let request_num = self.stats.requests.fetch_add(1, Ordering::Relaxed) + 1;

        trace!(
            upstream_id = %self.id,
            request_num = request_num,
            target_count = self.targets.len(),
            "Starting peer selection with metadata"
        );

        let mut attempts = 0;
        let max_attempts = self.targets.len() * 2;

        while attempts < max_attempts {
            attempts += 1;

            trace!(
                upstream_id = %self.id,
                attempt = attempts,
                max_attempts = max_attempts,
                "Attempting to select peer"
            );

            let selection = match self.load_balancer.select(context).await {
                Ok(s) => s,
                Err(e) => {
                    // NOTE(review): retrying only helps if health changes
                    // concurrently; otherwise this loops up to max_attempts
                    // on the same failure, and the final error loses the
                    // underlying cause.
                    warn!(
                        upstream_id = %self.id,
                        attempt = attempts,
                        error = %e,
                        "Load balancer selection failed"
                    );
                    continue;
                }
            };

            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                attempt = attempts,
                "Load balancer selected target"
            );

            // NOTE(review): when the breaker is open the selection is
            // dropped without calling `load_balancer.release`; stateful
            // balancers that count in-flight selections could drift —
            // confirm against balancer implementations.
            let breakers = self.circuit_breakers.read().await;
            if let Some(breaker) = breakers.get(&selection.address) {
                if !breaker.is_closed() {
                    debug!(
                        upstream_id = %self.id,
                        target = %selection.address,
                        attempt = attempts,
                        "Circuit breaker is open, skipping target"
                    );
                    self.stats
                        .circuit_breaker_trips
                        .fetch_add(1, Ordering::Relaxed);
                    continue;
                }
            }

            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                "Creating peer for upstream (Pingora handles connection reuse)"
            );
            let peer = self.create_peer(&selection)?;

            debug!(
                upstream_id = %self.id,
                target = %selection.address,
                attempt = attempts,
                metadata_keys = ?selection.metadata.keys().collect::<Vec<_>>(),
                "Selected upstream peer with metadata"
            );

            self.stats.successes.fetch_add(1, Ordering::Relaxed);
            return Ok((peer, selection.metadata));
        }

        self.stats.failures.fetch_add(1, Ordering::Relaxed);
        error!(
            upstream_id = %self.id,
            attempts = attempts,
            max_attempts = max_attempts,
            "Failed to select upstream after max attempts"
        );
        Err(SentinelError::upstream(
            self.id.to_string(),
            "Failed to select upstream after max attempts",
        ))
    }
1183
1184 pub async fn select_peer(&self, context: Option<&RequestContext>) -> SentinelResult<HttpPeer> {
1186 self.select_peer_with_metadata(context)
1188 .await
1189 .map(|(peer, _)| peer)
1190 }
1191
    /// Builds a Pingora `HttpPeer` for a selected target: resolves the
    /// address, applies pool timeouts and TCP keepalive, and configures
    /// ALPN / TLS verification / mTLS when TLS is enabled.
    ///
    /// # Errors
    /// Returns a retryable `Upstream` error when DNS resolution fails or
    /// yields no addresses, and a `Tls` error when the configured mTLS
    /// client certificate cannot be loaded.
    fn create_peer(&self, selection: &TargetSelection) -> SentinelResult<HttpPeer> {
        // SNI: explicit config override, otherwise the host portion of the
        // selected address (everything before the first ':').
        let sni_hostname = self.tls_sni.clone().unwrap_or_else(|| {
            selection
                .address
                .split(':')
                .next()
                .unwrap_or(&selection.address)
                .to_string()
        });

        // NOTE(review): `to_socket_addrs` performs synchronous DNS
        // resolution and can block the async executor thread — presumably
        // addresses are usually IP literals; confirm, or move resolution to
        // a blocking task. Only the first resolved address is used.
        let resolved_address = selection
            .address
            .to_socket_addrs()
            .map_err(|e| {
                error!(
                    upstream = %self.id,
                    address = %selection.address,
                    error = %e,
                    "Failed to resolve upstream address"
                );
                SentinelError::Upstream {
                    upstream: self.id.to_string(),
                    message: format!("DNS resolution failed for {}: {}", selection.address, e),
                    retryable: true,
                    source: None,
                }
            })?
            .next()
            .ok_or_else(|| {
                error!(
                    upstream = %self.id,
                    address = %selection.address,
                    "No addresses returned from DNS resolution"
                );
                SentinelError::Upstream {
                    upstream: self.id.to_string(),
                    message: format!("No addresses for {}", selection.address),
                    retryable: true,
                    source: None,
                }
            })?;

        let mut peer = HttpPeer::new(resolved_address, self.tls_enabled, sni_hostname.clone());

        // Connection reuse / timeout options from the pool configuration.
        peer.options.idle_timeout = Some(self.pool_config.idle_timeout);

        peer.options.connection_timeout = Some(self.pool_config.connection_timeout);
        // Hard-coded 10s overall connect budget (TCP + TLS handshake).
        peer.options.total_connection_timeout = Some(Duration::from_secs(10));

        peer.options.read_timeout = Some(self.pool_config.read_timeout);
        peer.options.write_timeout = Some(self.pool_config.write_timeout);

        // Fixed TCP keepalive: probe after 60s idle, every 10s, 3 probes.
        peer.options.tcp_keepalive = Some(pingora::protocols::TcpKeepalive {
            idle: Duration::from_secs(60),
            interval: Duration::from_secs(10),
            count: 3,
            #[cfg(target_os = "linux")]
            user_timeout: Duration::from_secs(60),
        });

        if self.tls_enabled {
            // ALPN from the configured HTTP version range.
            // NOTE(review): the `(1, 2)` pattern is already covered by
            // `(_, 2)`; kept as-is to avoid any behavior drift.
            let alpn = match (self.http_version.min_version, self.http_version.max_version) {
                (2, _) => {
                    pingora::upstreams::peer::ALPN::H2
                }
                (1, 2) | (_, 2) => {
                    pingora::upstreams::peer::ALPN::H2H1
                }
                _ => {
                    pingora::upstreams::peer::ALPN::H1
                }
            };
            peer.options.alpn = alpn;

            if let Some(ref tls_config) = self.tls_config {
                // Explicit opt-out of certificate verification; loudly
                // logged because it is insecure.
                if tls_config.insecure_skip_verify {
                    peer.options.verify_cert = false;
                    peer.options.verify_hostname = false;
                    warn!(
                        upstream_id = %self.id,
                        target = %selection.address,
                        "TLS certificate verification DISABLED (insecure_skip_verify=true)"
                    );
                }

                // Accept the configured SNI as an alternative CN during
                // hostname verification.
                if let Some(ref sni) = tls_config.sni {
                    peer.options.alternative_cn = Some(sni.clone());
                    trace!(
                        upstream_id = %self.id,
                        target = %selection.address,
                        alternative_cn = %sni,
                        "Set alternative CN for TLS verification"
                    );
                }

                // mTLS: load the client certificate/key pair when both are
                // configured; a load failure is a hard error.
                if let (Some(cert_path), Some(key_path)) =
                    (&tls_config.client_cert, &tls_config.client_key)
                {
                    match crate::tls::load_client_cert_key(cert_path, key_path) {
                        Ok(cert_key) => {
                            peer.client_cert_key = Some(cert_key);
                            info!(
                                upstream_id = %self.id,
                                target = %selection.address,
                                cert_path = ?cert_path,
                                "mTLS client certificate configured"
                            );
                        }
                        Err(e) => {
                            error!(
                                upstream_id = %self.id,
                                target = %selection.address,
                                error = %e,
                                "Failed to load mTLS client certificate"
                            );
                            return Err(SentinelError::Tls {
                                message: format!("Failed to load client certificate: {}", e),
                                source: None,
                            });
                        }
                    }
                }
            }

            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                alpn = ?peer.options.alpn,
                min_version = self.http_version.min_version,
                max_version = self.http_version.max_version,
                verify_cert = peer.options.verify_cert,
                verify_hostname = peer.options.verify_hostname,
                "Configured ALPN and TLS options for HTTP version negotiation"
            );
        }

        // H2 keepalive ping, only when enabled (non-zero interval).
        if self.http_version.max_version >= 2 {
            if !self.http_version.h2_ping_interval.is_zero() {
                peer.options.h2_ping_interval = Some(self.http_version.h2_ping_interval);
                trace!(
                    upstream_id = %self.id,
                    target = %selection.address,
                    h2_ping_interval_secs = self.http_version.h2_ping_interval.as_secs(),
                    "Configured H2 ping interval"
                );
            }
        }

        trace!(
            upstream_id = %self.id,
            target = %selection.address,
            tls = self.tls_enabled,
            sni = %sni_hostname,
            idle_timeout_secs = self.pool_config.idle_timeout.as_secs(),
            http_max_version = self.http_version.max_version,
            "Created peer with Pingora connection pooling enabled"
        );

        Ok(peer)
    }
1380
1381 pub async fn report_result(&self, target: &str, success: bool) {
1388 trace!(
1389 upstream_id = %self.id,
1390 target = %target,
1391 success = success,
1392 "Reporting connection result"
1393 );
1394
1395 if success {
1396 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1397 breaker.record_success();
1398 trace!(
1399 upstream_id = %self.id,
1400 target = %target,
1401 "Recorded success in circuit breaker"
1402 );
1403 }
1404 self.load_balancer.report_health(target, true).await;
1405 } else {
1406 let breaker_opened =
1407 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1408 let opened = breaker.record_failure();
1409 debug!(
1410 upstream_id = %self.id,
1411 target = %target,
1412 circuit_breaker_opened = opened,
1413 "Recorded failure in circuit breaker"
1414 );
1415 opened
1416 } else {
1417 false
1418 };
1419
1420 if breaker_opened {
1425 self.load_balancer.report_health(target, false).await;
1426 }
1427
1428 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1429 warn!(
1430 upstream_id = %self.id,
1431 target = %target,
1432 circuit_breaker_opened = breaker_opened,
1433 "Connection failure reported for target"
1434 );
1435 }
1436 }
1437
1438 pub async fn report_result_with_latency(
1446 &self,
1447 target: &str,
1448 success: bool,
1449 latency: Option<Duration>,
1450 ) {
1451 trace!(
1452 upstream_id = %self.id,
1453 target = %target,
1454 success = success,
1455 latency_ms = latency.map(|l| l.as_millis() as u64),
1456 "Reporting result with latency for adaptive LB"
1457 );
1458
1459 if success {
1461 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1462 breaker.record_success();
1463 }
1464 self.load_balancer
1466 .report_result_with_latency(target, true, latency)
1467 .await;
1468 } else {
1469 let breaker_opened =
1470 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1471 breaker.record_failure()
1472 } else {
1473 false
1474 };
1475 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1476
1477 if breaker_opened {
1482 self.load_balancer
1483 .report_result_with_latency(target, false, latency)
1484 .await;
1485 }
1486 }
1487 }
1488
    /// Returns the pool-wide atomic counters (requests, successes,
    /// failures) accumulated over this pool's lifetime; callers load
    /// the atomics themselves.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }
1493
    /// The identifier of the upstream this pool serves.
    pub fn id(&self) -> &UpstreamId {
        &self.id
    }
1498
    /// Number of configured targets in this pool, regardless of their
    /// current health status.
    pub fn target_count(&self) -> usize {
        self.targets.len()
    }
1503
1504 pub fn pool_config(&self) -> PoolConfigSnapshot {
1506 PoolConfigSnapshot {
1507 max_connections: self.pool_config.max_connections,
1508 max_idle: self.pool_config.max_idle,
1509 idle_timeout_secs: self.pool_config.idle_timeout.as_secs(),
1510 max_lifetime_secs: self.pool_config.max_lifetime.map(|d| d.as_secs()),
1511 connection_timeout_secs: self.pool_config.connection_timeout.as_secs(),
1512 read_timeout_secs: self.pool_config.read_timeout.as_secs(),
1513 write_timeout_secs: self.pool_config.write_timeout.as_secs(),
1514 }
1515 }
1516
1517 pub async fn has_healthy_targets(&self) -> bool {
1521 let healthy = self.load_balancer.healthy_targets().await;
1522 !healthy.is_empty()
1523 }
1524
1525 pub async fn select_shadow_target(
1530 &self,
1531 context: Option<&RequestContext>,
1532 ) -> SentinelResult<ShadowTarget> {
1533 let selection = self.load_balancer.select(context).await?;
1535
1536 let breakers = self.circuit_breakers.read().await;
1538 if let Some(breaker) = breakers.get(&selection.address) {
1539 if !breaker.is_closed() {
1540 return Err(SentinelError::upstream(
1541 self.id.to_string(),
1542 "Circuit breaker is open for shadow target",
1543 ));
1544 }
1545 }
1546
1547 let (host, port) = if selection.address.contains(':') {
1549 let parts: Vec<&str> = selection.address.rsplitn(2, ':').collect();
1550 if parts.len() == 2 {
1551 (
1552 parts[1].to_string(),
1553 parts[0].parse::<u16>().unwrap_or(if self.tls_enabled { 443 } else { 80 }),
1554 )
1555 } else {
1556 (selection.address.clone(), if self.tls_enabled { 443 } else { 80 })
1557 }
1558 } else {
1559 (selection.address.clone(), if self.tls_enabled { 443 } else { 80 })
1560 };
1561
1562 Ok(ShadowTarget {
1563 scheme: if self.tls_enabled { "https" } else { "http" }.to_string(),
1564 host,
1565 port,
1566 sni: self.tls_sni.clone(),
1567 })
1568 }
1569
    /// Whether connections to this upstream are made over TLS.
    pub fn is_tls_enabled(&self) -> bool {
        self.tls_enabled
    }
1574
1575 pub async fn shutdown(&self) {
1579 info!(
1580 upstream_id = %self.id,
1581 target_count = self.targets.len(),
1582 total_requests = self.stats.requests.load(Ordering::Relaxed),
1583 total_successes = self.stats.successes.load(Ordering::Relaxed),
1584 total_failures = self.stats.failures.load(Ordering::Relaxed),
1585 "Shutting down upstream pool"
1586 );
1587 debug!(upstream_id = %self.id, "Upstream pool shutdown complete");
1589 }
1590}