1use async_trait::async_trait;
7use pingora::upstreams::peer::HttpPeer;
8use std::collections::HashMap;
9use std::net::ToSocketAddrs;
10use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, trace, warn};
15
16use sentinel_common::{
17 errors::{SentinelError, SentinelResult},
18 types::{CircuitBreakerConfig, LoadBalancingAlgorithm},
19 CircuitBreaker, UpstreamId,
20};
21use sentinel_config::UpstreamConfig;
22
/// A single backend endpoint, stored as separate host and port parts.
#[derive(Clone, Debug)]
pub struct UpstreamTarget {
    /// Host part of the endpoint (everything before the port).
    pub address: String,
    /// Port of the endpoint.
    pub port: u16,
    /// Weight attached to this target; `from_address` defaults it to 100.
    pub weight: u32,
}
40
41impl UpstreamTarget {
42 pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
44 Self {
45 address: address.into(),
46 port,
47 weight,
48 }
49 }
50
51 pub fn from_address(addr: &str) -> Option<Self> {
53 let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
54 if parts.len() == 2 {
55 let port = parts[0].parse().ok()?;
56 let address = parts[1].to_string();
57 Some(Self {
58 address,
59 port,
60 weight: 100,
61 })
62 } else {
63 None
64 }
65 }
66
67 pub fn from_config(config: &sentinel_config::UpstreamTarget) -> Option<Self> {
69 Self::from_address(&config.address).map(|mut t| {
70 t.weight = config.weight;
71 t
72 })
73 }
74
75 pub fn full_address(&self) -> String {
77 format!("{}:{}", self.address, self.port)
78 }
79}
80
81pub mod adaptive;
87pub mod consistent_hash;
88pub mod health;
89pub mod inference_health;
90pub mod least_tokens;
91pub mod locality;
92pub mod maglev;
93pub mod p2c;
94pub mod peak_ewma;
95pub mod sticky_session;
96pub mod subset;
97pub mod weighted_least_conn;
98
99pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
101pub use consistent_hash::{ConsistentHashBalancer, ConsistentHashConfig};
102pub use health::{ActiveHealthChecker, HealthCheckRunner};
103pub use inference_health::InferenceHealthCheck;
104pub use least_tokens::{LeastTokensQueuedBalancer, LeastTokensQueuedConfig, LeastTokensQueuedTargetStats};
105pub use locality::{LocalityAwareBalancer, LocalityAwareConfig};
106pub use maglev::{MaglevBalancer, MaglevConfig};
107pub use p2c::{P2cBalancer, P2cConfig};
108pub use peak_ewma::{PeakEwmaBalancer, PeakEwmaConfig};
109pub use sticky_session::{StickySessionBalancer, StickySessionRuntimeConfig};
110pub use subset::{SubsetBalancer, SubsetConfig};
111pub use weighted_least_conn::{WeightedLeastConnBalancer, WeightedLeastConnConfig};
112
/// Request attributes handed to a load balancer when it selects a target.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Socket address of the client, when known.
    pub client_ip: Option<std::net::SocketAddr>,
    /// Request headers as name/value pairs.
    pub headers: HashMap<String, String>,
    /// Request path.
    pub path: String,
    /// Request method.
    pub method: String,
}
121
122#[async_trait]
124pub trait LoadBalancer: Send + Sync {
125 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection>;
127
128 async fn report_health(&self, address: &str, healthy: bool);
130
131 async fn healthy_targets(&self) -> Vec<String>;
133
134 async fn release(&self, _selection: &TargetSelection) {
136 }
138
139 async fn report_result(
141 &self,
142 _selection: &TargetSelection,
143 _success: bool,
144 _latency: Option<Duration>,
145 ) {
146 }
148
149 async fn report_result_with_latency(
156 &self,
157 address: &str,
158 success: bool,
159 _latency: Option<Duration>,
160 ) {
161 self.report_health(address, success).await;
163 }
164}
165
/// The outcome of a load-balancer selection.
#[derive(Clone, Debug)]
pub struct TargetSelection {
    /// Selected target as a `host:port` string.
    pub address: String,
    /// Weight of the selected target.
    pub weight: u32,
    /// Extra key/value data attached by the balancer; the built-in balancers
    /// in this file leave it empty.
    pub metadata: HashMap<String, String>,
}
176
177pub struct UpstreamPool {
179 id: UpstreamId,
181 targets: Vec<UpstreamTarget>,
183 load_balancer: Arc<dyn LoadBalancer>,
185 pool_config: ConnectionPoolConfig,
187 http_version: HttpVersionOptions,
189 tls_enabled: bool,
191 tls_sni: Option<String>,
193 tls_config: Option<sentinel_config::UpstreamTlsConfig>,
195 circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
197 stats: Arc<PoolStats>,
199}
200
/// Runtime form of the connection-pool and timeout settings, with second
/// counts already converted to [`Duration`] values.
pub struct ConnectionPoolConfig {
    /// Configured maximum number of connections.
    pub max_connections: usize,
    /// Configured maximum number of idle connections.
    pub max_idle: usize,
    /// How long an idle connection may be kept.
    pub idle_timeout: Duration,
    /// Optional upper bound on a connection's lifetime.
    pub max_lifetime: Option<Duration>,
    /// Timeout for establishing a connection.
    pub connection_timeout: Duration,
    /// Timeout for reads from the upstream.
    pub read_timeout: Duration,
    /// Timeout for writes to the upstream.
    pub write_timeout: Duration,
}
225
/// HTTP version bounds and HTTP/2 tuning for an upstream.
pub struct HttpVersionOptions {
    /// Lowest HTTP major version to use.
    pub min_version: u8,
    /// Highest HTTP major version to use.
    pub max_version: u8,
    /// Interval between HTTP/2 pings; a zero duration leaves pings unconfigured.
    pub h2_ping_interval: Duration,
    /// Configured maximum number of HTTP/2 streams.
    pub max_h2_streams: usize,
}
237
238impl ConnectionPoolConfig {
239 pub fn from_config(
241 pool_config: &sentinel_config::ConnectionPoolConfig,
242 timeouts: &sentinel_config::UpstreamTimeouts,
243 ) -> Self {
244 Self {
245 max_connections: pool_config.max_connections,
246 max_idle: pool_config.max_idle,
247 idle_timeout: Duration::from_secs(pool_config.idle_timeout_secs),
248 max_lifetime: pool_config.max_lifetime_secs.map(Duration::from_secs),
249 connection_timeout: Duration::from_secs(timeouts.connect_secs),
250 read_timeout: Duration::from_secs(timeouts.read_secs),
251 write_timeout: Duration::from_secs(timeouts.write_secs),
252 }
253 }
254}
255
/// Lock-free counters for an upstream pool. All counters start at zero.
#[derive(Default)]
pub struct PoolStats {
    /// Number of peer-selection requests.
    pub requests: AtomicU64,
    /// Number of selections that produced a peer.
    pub successes: AtomicU64,
    /// Number of selections that gave up without a peer.
    pub failures: AtomicU64,
    /// Retry counter.
    pub retries: AtomicU64,
    /// Number of times a target was skipped because its circuit breaker was
    /// not closed.
    pub circuit_breaker_trips: AtomicU64,
}
272
/// Destination for shadow traffic.
#[derive(Debug, Clone)]
pub struct ShadowTarget {
    /// URL scheme, e.g. `http` or `https`.
    pub scheme: String,
    /// Host name or IP literal.
    pub host: String,
    /// Destination port.
    pub port: u16,
    /// Optional SNI value.
    pub sni: Option<String>,
}

impl ShadowTarget {
    /// Builds `scheme://host[:port]path` for this target.
    ///
    /// The port is omitted when it is the default for the scheme (`http`/80,
    /// `https`/443). `path` is appended verbatim. An IPv6 literal host that is
    /// not already bracketed is wrapped in `[...]`, because a bare IPv6
    /// literal cannot be told apart from the port separator in a URL.
    pub fn build_url(&self, path: &str) -> String {
        let is_default_port = matches!(
            (self.scheme.as_str(), self.port),
            ("http", 80) | ("https", 443)
        );
        let port_suffix = if is_default_port {
            String::new()
        } else {
            format!(":{}", self.port)
        };

        // Host names never contain ':', so a colon means an IPv6 literal.
        let needs_brackets = self.host.contains(':') && !self.host.starts_with('[');
        if needs_brackets {
            format!("{}://[{}]{}{}", self.scheme, self.host, port_suffix, path)
        } else {
            format!("{}://{}{}{}", self.scheme, self.host, port_suffix, path)
        }
    }
}
296
/// Plain-data view of connection-pool settings, with all durations expressed
/// in whole seconds.
#[derive(Clone, Debug)]
pub struct PoolConfigSnapshot {
    /// Maximum number of connections.
    pub max_connections: usize,
    /// Maximum number of idle connections.
    pub max_idle: usize,
    /// Idle timeout in seconds.
    pub idle_timeout_secs: u64,
    /// Optional maximum connection lifetime in seconds.
    pub max_lifetime_secs: Option<u64>,
    /// Connection timeout in seconds.
    pub connection_timeout_secs: u64,
    /// Read timeout in seconds.
    pub read_timeout_secs: u64,
    /// Write timeout in seconds.
    pub write_timeout_secs: u64,
}
315
316struct RoundRobinBalancer {
318 targets: Vec<UpstreamTarget>,
319 current: AtomicUsize,
320 health_status: Arc<RwLock<HashMap<String, bool>>>,
321}
322
323impl RoundRobinBalancer {
324 fn new(targets: Vec<UpstreamTarget>) -> Self {
325 let mut health_status = HashMap::new();
326 for target in &targets {
327 health_status.insert(target.full_address(), true);
328 }
329
330 Self {
331 targets,
332 current: AtomicUsize::new(0),
333 health_status: Arc::new(RwLock::new(health_status)),
334 }
335 }
336}
337
338#[async_trait]
339impl LoadBalancer for RoundRobinBalancer {
340 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
341 trace!(
342 total_targets = self.targets.len(),
343 algorithm = "round_robin",
344 "Selecting upstream target"
345 );
346
347 let health = self.health_status.read().await;
348 let healthy_targets: Vec<_> = self
349 .targets
350 .iter()
351 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
352 .collect();
353
354 if healthy_targets.is_empty() {
355 warn!(
356 total_targets = self.targets.len(),
357 algorithm = "round_robin",
358 "No healthy upstream targets available"
359 );
360 return Err(SentinelError::NoHealthyUpstream);
361 }
362
363 let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
364 let target = healthy_targets[index];
365
366 trace!(
367 selected_target = %target.full_address(),
368 healthy_count = healthy_targets.len(),
369 index = index,
370 algorithm = "round_robin",
371 "Selected target via round robin"
372 );
373
374 Ok(TargetSelection {
375 address: target.full_address(),
376 weight: target.weight,
377 metadata: HashMap::new(),
378 })
379 }
380
381 async fn report_health(&self, address: &str, healthy: bool) {
382 trace!(
383 target = %address,
384 healthy = healthy,
385 algorithm = "round_robin",
386 "Updating target health status"
387 );
388 self.health_status
389 .write()
390 .await
391 .insert(address.to_string(), healthy);
392 }
393
394 async fn healthy_targets(&self) -> Vec<String> {
395 self.health_status
396 .read()
397 .await
398 .iter()
399 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
400 .collect()
401 }
402}
403
404struct RandomBalancer {
406 targets: Vec<UpstreamTarget>,
407 health_status: Arc<RwLock<HashMap<String, bool>>>,
408}
409
410impl RandomBalancer {
411 fn new(targets: Vec<UpstreamTarget>) -> Self {
412 let mut health_status = HashMap::new();
413 for target in &targets {
414 health_status.insert(target.full_address(), true);
415 }
416
417 Self {
418 targets,
419 health_status: Arc::new(RwLock::new(health_status)),
420 }
421 }
422}
423
424#[async_trait]
425impl LoadBalancer for RandomBalancer {
426 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
427 use rand::seq::SliceRandom;
428
429 trace!(
430 total_targets = self.targets.len(),
431 algorithm = "random",
432 "Selecting upstream target"
433 );
434
435 let health = self.health_status.read().await;
436 let healthy_targets: Vec<_> = self
437 .targets
438 .iter()
439 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
440 .collect();
441
442 if healthy_targets.is_empty() {
443 warn!(
444 total_targets = self.targets.len(),
445 algorithm = "random",
446 "No healthy upstream targets available"
447 );
448 return Err(SentinelError::NoHealthyUpstream);
449 }
450
451 let mut rng = rand::thread_rng();
452 let target = healthy_targets
453 .choose(&mut rng)
454 .ok_or(SentinelError::NoHealthyUpstream)?;
455
456 trace!(
457 selected_target = %target.full_address(),
458 healthy_count = healthy_targets.len(),
459 algorithm = "random",
460 "Selected target via random selection"
461 );
462
463 Ok(TargetSelection {
464 address: target.full_address(),
465 weight: target.weight,
466 metadata: HashMap::new(),
467 })
468 }
469
470 async fn report_health(&self, address: &str, healthy: bool) {
471 trace!(
472 target = %address,
473 healthy = healthy,
474 algorithm = "random",
475 "Updating target health status"
476 );
477 self.health_status
478 .write()
479 .await
480 .insert(address.to_string(), healthy);
481 }
482
483 async fn healthy_targets(&self) -> Vec<String> {
484 self.health_status
485 .read()
486 .await
487 .iter()
488 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
489 .collect()
490 }
491}
492
493struct LeastConnectionsBalancer {
495 targets: Vec<UpstreamTarget>,
496 connections: Arc<RwLock<HashMap<String, usize>>>,
497 health_status: Arc<RwLock<HashMap<String, bool>>>,
498}
499
500impl LeastConnectionsBalancer {
501 fn new(targets: Vec<UpstreamTarget>) -> Self {
502 let mut health_status = HashMap::new();
503 let mut connections = HashMap::new();
504
505 for target in &targets {
506 let addr = target.full_address();
507 health_status.insert(addr.clone(), true);
508 connections.insert(addr, 0);
509 }
510
511 Self {
512 targets,
513 connections: Arc::new(RwLock::new(connections)),
514 health_status: Arc::new(RwLock::new(health_status)),
515 }
516 }
517}
518
519#[async_trait]
520impl LoadBalancer for LeastConnectionsBalancer {
521 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
522 trace!(
523 total_targets = self.targets.len(),
524 algorithm = "least_connections",
525 "Selecting upstream target"
526 );
527
528 let health = self.health_status.read().await;
529 let conns = self.connections.read().await;
530
531 let mut best_target = None;
532 let mut min_connections = usize::MAX;
533
534 for target in &self.targets {
535 let addr = target.full_address();
536 if !*health.get(&addr).unwrap_or(&true) {
537 trace!(
538 target = %addr,
539 algorithm = "least_connections",
540 "Skipping unhealthy target"
541 );
542 continue;
543 }
544
545 let conn_count = *conns.get(&addr).unwrap_or(&0);
546 trace!(
547 target = %addr,
548 connections = conn_count,
549 "Evaluating target connection count"
550 );
551 if conn_count < min_connections {
552 min_connections = conn_count;
553 best_target = Some(target);
554 }
555 }
556
557 match best_target {
558 Some(target) => {
559 trace!(
560 selected_target = %target.full_address(),
561 connections = min_connections,
562 algorithm = "least_connections",
563 "Selected target with fewest connections"
564 );
565 Ok(TargetSelection {
566 address: target.full_address(),
567 weight: target.weight,
568 metadata: HashMap::new(),
569 })
570 }
571 None => {
572 warn!(
573 total_targets = self.targets.len(),
574 algorithm = "least_connections",
575 "No healthy upstream targets available"
576 );
577 Err(SentinelError::NoHealthyUpstream)
578 }
579 }
580 }
581
582 async fn report_health(&self, address: &str, healthy: bool) {
583 trace!(
584 target = %address,
585 healthy = healthy,
586 algorithm = "least_connections",
587 "Updating target health status"
588 );
589 self.health_status
590 .write()
591 .await
592 .insert(address.to_string(), healthy);
593 }
594
595 async fn healthy_targets(&self) -> Vec<String> {
596 self.health_status
597 .read()
598 .await
599 .iter()
600 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
601 .collect()
602 }
603}
604
605struct WeightedBalancer {
607 targets: Vec<UpstreamTarget>,
608 weights: Vec<u32>,
609 current_index: AtomicUsize,
610 health_status: Arc<RwLock<HashMap<String, bool>>>,
611}
612
613#[async_trait]
614impl LoadBalancer for WeightedBalancer {
615 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
616 trace!(
617 total_targets = self.targets.len(),
618 algorithm = "weighted",
619 "Selecting upstream target"
620 );
621
622 let health = self.health_status.read().await;
623 let healthy_indices: Vec<_> = self
624 .targets
625 .iter()
626 .enumerate()
627 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
628 .map(|(i, _)| i)
629 .collect();
630
631 if healthy_indices.is_empty() {
632 warn!(
633 total_targets = self.targets.len(),
634 algorithm = "weighted",
635 "No healthy upstream targets available"
636 );
637 return Err(SentinelError::NoHealthyUpstream);
638 }
639
640 let idx = self.current_index.fetch_add(1, Ordering::Relaxed) % healthy_indices.len();
641 let target_idx = healthy_indices[idx];
642 let target = &self.targets[target_idx];
643 let weight = self.weights.get(target_idx).copied().unwrap_or(1);
644
645 trace!(
646 selected_target = %target.full_address(),
647 weight = weight,
648 healthy_count = healthy_indices.len(),
649 algorithm = "weighted",
650 "Selected target via weighted round robin"
651 );
652
653 Ok(TargetSelection {
654 address: target.full_address(),
655 weight,
656 metadata: HashMap::new(),
657 })
658 }
659
660 async fn report_health(&self, address: &str, healthy: bool) {
661 trace!(
662 target = %address,
663 healthy = healthy,
664 algorithm = "weighted",
665 "Updating target health status"
666 );
667 self.health_status
668 .write()
669 .await
670 .insert(address.to_string(), healthy);
671 }
672
673 async fn healthy_targets(&self) -> Vec<String> {
674 self.health_status
675 .read()
676 .await
677 .iter()
678 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
679 .collect()
680 }
681}
682
683struct IpHashBalancer {
685 targets: Vec<UpstreamTarget>,
686 health_status: Arc<RwLock<HashMap<String, bool>>>,
687}
688
689#[async_trait]
690impl LoadBalancer for IpHashBalancer {
691 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
692 trace!(
693 total_targets = self.targets.len(),
694 algorithm = "ip_hash",
695 "Selecting upstream target"
696 );
697
698 let health = self.health_status.read().await;
699 let healthy_targets: Vec<_> = self
700 .targets
701 .iter()
702 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
703 .collect();
704
705 if healthy_targets.is_empty() {
706 warn!(
707 total_targets = self.targets.len(),
708 algorithm = "ip_hash",
709 "No healthy upstream targets available"
710 );
711 return Err(SentinelError::NoHealthyUpstream);
712 }
713
714 let (hash, client_ip_str) = if let Some(ctx) = context {
716 if let Some(ip) = &ctx.client_ip {
717 use std::hash::{Hash, Hasher};
718 let mut hasher = std::collections::hash_map::DefaultHasher::new();
719 ip.hash(&mut hasher);
720 (hasher.finish(), Some(ip.to_string()))
721 } else {
722 (0, None)
723 }
724 } else {
725 (0, None)
726 };
727
728 let idx = (hash as usize) % healthy_targets.len();
729 let target = healthy_targets[idx];
730
731 trace!(
732 selected_target = %target.full_address(),
733 client_ip = client_ip_str.as_deref().unwrap_or("unknown"),
734 hash = hash,
735 index = idx,
736 healthy_count = healthy_targets.len(),
737 algorithm = "ip_hash",
738 "Selected target via IP hash"
739 );
740
741 Ok(TargetSelection {
742 address: target.full_address(),
743 weight: target.weight,
744 metadata: HashMap::new(),
745 })
746 }
747
748 async fn report_health(&self, address: &str, healthy: bool) {
749 trace!(
750 target = %address,
751 healthy = healthy,
752 algorithm = "ip_hash",
753 "Updating target health status"
754 );
755 self.health_status
756 .write()
757 .await
758 .insert(address.to_string(), healthy);
759 }
760
761 async fn healthy_targets(&self) -> Vec<String> {
762 self.health_status
763 .read()
764 .await
765 .iter()
766 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
767 .collect()
768 }
769}
770
771impl UpstreamPool {
772 pub async fn new(config: UpstreamConfig) -> SentinelResult<Self> {
774 let id = UpstreamId::new(&config.id);
775
776 info!(
777 upstream_id = %config.id,
778 target_count = config.targets.len(),
779 algorithm = ?config.load_balancing,
780 "Creating upstream pool"
781 );
782
783 let targets: Vec<UpstreamTarget> = config
785 .targets
786 .iter()
787 .filter_map(UpstreamTarget::from_config)
788 .collect();
789
790 if targets.is_empty() {
791 error!(
792 upstream_id = %config.id,
793 "No valid upstream targets configured"
794 );
795 return Err(SentinelError::Config {
796 message: "No valid upstream targets".to_string(),
797 source: None,
798 });
799 }
800
801 for target in &targets {
802 debug!(
803 upstream_id = %config.id,
804 target = %target.full_address(),
805 weight = target.weight,
806 "Registered upstream target"
807 );
808 }
809
810 debug!(
812 upstream_id = %config.id,
813 algorithm = ?config.load_balancing,
814 "Creating load balancer"
815 );
816 let load_balancer =
817 Self::create_load_balancer(&config.load_balancing, &targets, &config)?;
818
819 debug!(
821 upstream_id = %config.id,
822 max_connections = config.connection_pool.max_connections,
823 max_idle = config.connection_pool.max_idle,
824 idle_timeout_secs = config.connection_pool.idle_timeout_secs,
825 connect_timeout_secs = config.timeouts.connect_secs,
826 read_timeout_secs = config.timeouts.read_secs,
827 write_timeout_secs = config.timeouts.write_secs,
828 "Creating connection pool configuration"
829 );
830 let pool_config =
831 ConnectionPoolConfig::from_config(&config.connection_pool, &config.timeouts);
832
833 let http_version = HttpVersionOptions {
835 min_version: config.http_version.min_version,
836 max_version: config.http_version.max_version,
837 h2_ping_interval: if config.http_version.h2_ping_interval_secs > 0 {
838 Duration::from_secs(config.http_version.h2_ping_interval_secs)
839 } else {
840 Duration::ZERO
841 },
842 max_h2_streams: config.http_version.max_h2_streams,
843 };
844
845 let tls_enabled = config.tls.is_some();
847 let tls_sni = config.tls.as_ref().and_then(|t| t.sni.clone());
848 let tls_config = config.tls.clone();
849
850 if let Some(ref tls) = tls_config {
852 if tls.client_cert.is_some() {
853 info!(
854 upstream_id = %config.id,
855 "mTLS enabled for upstream (client certificate configured)"
856 );
857 }
858 }
859
860 if http_version.max_version >= 2 && tls_enabled {
861 info!(
862 upstream_id = %config.id,
863 "HTTP/2 enabled for upstream (via ALPN)"
864 );
865 }
866
867 let mut circuit_breakers = HashMap::new();
869 for target in &targets {
870 trace!(
871 upstream_id = %config.id,
872 target = %target.full_address(),
873 "Initializing circuit breaker for target"
874 );
875 circuit_breakers.insert(
876 target.full_address(),
877 CircuitBreaker::new(CircuitBreakerConfig::default()),
878 );
879 }
880
881 let pool = Self {
882 id: id.clone(),
883 targets,
884 load_balancer,
885 pool_config,
886 http_version,
887 tls_enabled,
888 tls_sni,
889 tls_config,
890 circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
891 stats: Arc::new(PoolStats::default()),
892 };
893
894 info!(
895 upstream_id = %id,
896 target_count = pool.targets.len(),
897 "Upstream pool created successfully"
898 );
899
900 Ok(pool)
901 }
902
903 fn create_load_balancer(
905 algorithm: &LoadBalancingAlgorithm,
906 targets: &[UpstreamTarget],
907 config: &UpstreamConfig,
908 ) -> SentinelResult<Arc<dyn LoadBalancer>> {
909 let balancer: Arc<dyn LoadBalancer> = match algorithm {
910 LoadBalancingAlgorithm::RoundRobin => {
911 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
912 }
913 LoadBalancingAlgorithm::LeastConnections => {
914 Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
915 }
916 LoadBalancingAlgorithm::Weighted => {
917 let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
918 Arc::new(WeightedBalancer {
919 targets: targets.to_vec(),
920 weights,
921 current_index: AtomicUsize::new(0),
922 health_status: Arc::new(RwLock::new(HashMap::new())),
923 })
924 }
925 LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
926 targets: targets.to_vec(),
927 health_status: Arc::new(RwLock::new(HashMap::new())),
928 }),
929 LoadBalancingAlgorithm::Random => {
930 Arc::new(RandomBalancer::new(targets.to_vec()))
931 }
932 LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
933 targets.to_vec(),
934 ConsistentHashConfig::default(),
935 )),
936 LoadBalancingAlgorithm::PowerOfTwoChoices => {
937 Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
938 }
939 LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
940 targets.to_vec(),
941 AdaptiveConfig::default(),
942 )),
943 LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
944 targets.to_vec(),
945 LeastTokensQueuedConfig::default(),
946 )),
947 LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
948 targets.to_vec(),
949 MaglevConfig::default(),
950 )),
951 LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
952 targets.to_vec(),
953 LocalityAwareConfig::default(),
954 )),
955 LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
956 targets.to_vec(),
957 PeakEwmaConfig::default(),
958 )),
959 LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
960 targets.to_vec(),
961 SubsetConfig::default(),
962 )),
963 LoadBalancingAlgorithm::WeightedLeastConnections => {
964 Arc::new(WeightedLeastConnBalancer::new(
965 targets.to_vec(),
966 WeightedLeastConnConfig::default(),
967 ))
968 }
969 LoadBalancingAlgorithm::Sticky => {
970 let sticky_config = config.sticky_session.as_ref().ok_or_else(|| {
972 SentinelError::Config {
973 message: format!(
974 "Upstream '{}' uses Sticky algorithm but no sticky_session config provided",
975 config.id
976 ),
977 source: None,
978 }
979 })?;
980
981 let runtime_config = StickySessionRuntimeConfig::from_config(sticky_config);
983
984 let fallback = Self::create_load_balancer_inner(&sticky_config.fallback, targets)?;
986
987 info!(
988 upstream_id = %config.id,
989 cookie_name = %runtime_config.cookie_name,
990 cookie_ttl_secs = runtime_config.cookie_ttl_secs,
991 fallback_algorithm = ?sticky_config.fallback,
992 "Creating sticky session balancer"
993 );
994
995 Arc::new(StickySessionBalancer::new(
996 targets.to_vec(),
997 runtime_config,
998 fallback,
999 ))
1000 }
1001 };
1002 Ok(balancer)
1003 }
1004
1005 fn create_load_balancer_inner(
1007 algorithm: &LoadBalancingAlgorithm,
1008 targets: &[UpstreamTarget],
1009 ) -> SentinelResult<Arc<dyn LoadBalancer>> {
1010 let balancer: Arc<dyn LoadBalancer> = match algorithm {
1011 LoadBalancingAlgorithm::RoundRobin => {
1012 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
1013 }
1014 LoadBalancingAlgorithm::LeastConnections => {
1015 Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
1016 }
1017 LoadBalancingAlgorithm::Weighted => {
1018 let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
1019 Arc::new(WeightedBalancer {
1020 targets: targets.to_vec(),
1021 weights,
1022 current_index: AtomicUsize::new(0),
1023 health_status: Arc::new(RwLock::new(HashMap::new())),
1024 })
1025 }
1026 LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
1027 targets: targets.to_vec(),
1028 health_status: Arc::new(RwLock::new(HashMap::new())),
1029 }),
1030 LoadBalancingAlgorithm::Random => {
1031 Arc::new(RandomBalancer::new(targets.to_vec()))
1032 }
1033 LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
1034 targets.to_vec(),
1035 ConsistentHashConfig::default(),
1036 )),
1037 LoadBalancingAlgorithm::PowerOfTwoChoices => {
1038 Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
1039 }
1040 LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
1041 targets.to_vec(),
1042 AdaptiveConfig::default(),
1043 )),
1044 LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
1045 targets.to_vec(),
1046 LeastTokensQueuedConfig::default(),
1047 )),
1048 LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
1049 targets.to_vec(),
1050 MaglevConfig::default(),
1051 )),
1052 LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
1053 targets.to_vec(),
1054 LocalityAwareConfig::default(),
1055 )),
1056 LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
1057 targets.to_vec(),
1058 PeakEwmaConfig::default(),
1059 )),
1060 LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
1061 targets.to_vec(),
1062 SubsetConfig::default(),
1063 )),
1064 LoadBalancingAlgorithm::WeightedLeastConnections => {
1065 Arc::new(WeightedLeastConnBalancer::new(
1066 targets.to_vec(),
1067 WeightedLeastConnConfig::default(),
1068 ))
1069 }
1070 LoadBalancingAlgorithm::Sticky => {
1071 return Err(SentinelError::Config {
1073 message: "Sticky algorithm cannot be used as fallback for sticky sessions"
1074 .to_string(),
1075 source: None,
1076 });
1077 }
1078 };
1079 Ok(balancer)
1080 }
1081
1082 pub async fn select_peer_with_metadata(
1088 &self,
1089 context: Option<&RequestContext>,
1090 ) -> SentinelResult<(HttpPeer, HashMap<String, String>)> {
1091 let request_num = self.stats.requests.fetch_add(1, Ordering::Relaxed) + 1;
1092
1093 trace!(
1094 upstream_id = %self.id,
1095 request_num = request_num,
1096 target_count = self.targets.len(),
1097 "Starting peer selection with metadata"
1098 );
1099
1100 let mut attempts = 0;
1101 let max_attempts = self.targets.len() * 2;
1102
1103 while attempts < max_attempts {
1104 attempts += 1;
1105
1106 trace!(
1107 upstream_id = %self.id,
1108 attempt = attempts,
1109 max_attempts = max_attempts,
1110 "Attempting to select peer"
1111 );
1112
1113 let selection = match self.load_balancer.select(context).await {
1114 Ok(s) => s,
1115 Err(e) => {
1116 warn!(
1117 upstream_id = %self.id,
1118 attempt = attempts,
1119 error = %e,
1120 "Load balancer selection failed"
1121 );
1122 continue;
1123 }
1124 };
1125
1126 trace!(
1127 upstream_id = %self.id,
1128 target = %selection.address,
1129 attempt = attempts,
1130 "Load balancer selected target"
1131 );
1132
1133 let breakers = self.circuit_breakers.read().await;
1135 if let Some(breaker) = breakers.get(&selection.address) {
1136 if !breaker.is_closed() {
1137 debug!(
1138 upstream_id = %self.id,
1139 target = %selection.address,
1140 attempt = attempts,
1141 "Circuit breaker is open, skipping target"
1142 );
1143 self.stats
1144 .circuit_breaker_trips
1145 .fetch_add(1, Ordering::Relaxed);
1146 continue;
1147 }
1148 }
1149
1150 trace!(
1152 upstream_id = %self.id,
1153 target = %selection.address,
1154 "Creating peer for upstream (Pingora handles connection reuse)"
1155 );
1156 let peer = self.create_peer(&selection)?;
1157
1158 debug!(
1159 upstream_id = %self.id,
1160 target = %selection.address,
1161 attempt = attempts,
1162 metadata_keys = ?selection.metadata.keys().collect::<Vec<_>>(),
1163 "Selected upstream peer with metadata"
1164 );
1165
1166 self.stats.successes.fetch_add(1, Ordering::Relaxed);
1167 return Ok((peer, selection.metadata));
1168 }
1169
1170 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1171 error!(
1172 upstream_id = %self.id,
1173 attempts = attempts,
1174 max_attempts = max_attempts,
1175 "Failed to select upstream after max attempts"
1176 );
1177 Err(SentinelError::upstream(
1178 self.id.to_string(),
1179 "Failed to select upstream after max attempts",
1180 ))
1181 }
1182
1183 pub async fn select_peer(&self, context: Option<&RequestContext>) -> SentinelResult<HttpPeer> {
1185 self.select_peer_with_metadata(context)
1187 .await
1188 .map(|(peer, _)| peer)
1189 }
1190
1191 fn create_peer(&self, selection: &TargetSelection) -> SentinelResult<HttpPeer> {
1197 let sni_hostname = self.tls_sni.clone().unwrap_or_else(|| {
1199 selection
1201 .address
1202 .split(':')
1203 .next()
1204 .unwrap_or(&selection.address)
1205 .to_string()
1206 });
1207
1208 let resolved_address = selection
1211 .address
1212 .to_socket_addrs()
1213 .map_err(|e| {
1214 error!(
1215 upstream = %self.id,
1216 address = %selection.address,
1217 error = %e,
1218 "Failed to resolve upstream address"
1219 );
1220 SentinelError::Upstream {
1221 upstream: self.id.to_string(),
1222 message: format!("DNS resolution failed for {}: {}", selection.address, e),
1223 retryable: true,
1224 source: None,
1225 }
1226 })?
1227 .next()
1228 .ok_or_else(|| {
1229 error!(
1230 upstream = %self.id,
1231 address = %selection.address,
1232 "No addresses returned from DNS resolution"
1233 );
1234 SentinelError::Upstream {
1235 upstream: self.id.to_string(),
1236 message: format!("No addresses for {}", selection.address),
1237 retryable: true,
1238 source: None,
1239 }
1240 })?;
1241
1242 let mut peer = HttpPeer::new(resolved_address, self.tls_enabled, sni_hostname.clone());
1244
1245 peer.options.idle_timeout = Some(self.pool_config.idle_timeout);
1249
1250 peer.options.connection_timeout = Some(self.pool_config.connection_timeout);
1252 peer.options.total_connection_timeout = Some(Duration::from_secs(10));
1253
1254 peer.options.read_timeout = Some(self.pool_config.read_timeout);
1256 peer.options.write_timeout = Some(self.pool_config.write_timeout);
1257
1258 peer.options.tcp_keepalive = Some(pingora::protocols::TcpKeepalive {
1260 idle: Duration::from_secs(60),
1261 interval: Duration::from_secs(10),
1262 count: 3,
1263 #[cfg(target_os = "linux")]
1265 user_timeout: Duration::from_secs(60),
1266 });
1267
1268 if self.tls_enabled {
1270 let alpn = match (self.http_version.min_version, self.http_version.max_version) {
1272 (2, _) => {
1273 pingora::upstreams::peer::ALPN::H2
1275 }
1276 (1, 2) | (_, 2) => {
1277 pingora::upstreams::peer::ALPN::H2H1
1279 }
1280 _ => {
1281 pingora::upstreams::peer::ALPN::H1
1283 }
1284 };
1285 peer.options.alpn = alpn;
1286
1287 if let Some(ref tls_config) = self.tls_config {
1289 if tls_config.insecure_skip_verify {
1291 peer.options.verify_cert = false;
1292 peer.options.verify_hostname = false;
1293 warn!(
1294 upstream_id = %self.id,
1295 target = %selection.address,
1296 "TLS certificate verification DISABLED (insecure_skip_verify=true)"
1297 );
1298 }
1299
1300 if let Some(ref sni) = tls_config.sni {
1302 peer.options.alternative_cn = Some(sni.clone());
1303 trace!(
1304 upstream_id = %self.id,
1305 target = %selection.address,
1306 alternative_cn = %sni,
1307 "Set alternative CN for TLS verification"
1308 );
1309 }
1310
1311 if let (Some(cert_path), Some(key_path)) =
1313 (&tls_config.client_cert, &tls_config.client_key)
1314 {
1315 match crate::tls::load_client_cert_key(cert_path, key_path) {
1316 Ok(cert_key) => {
1317 peer.client_cert_key = Some(cert_key);
1318 info!(
1319 upstream_id = %self.id,
1320 target = %selection.address,
1321 cert_path = ?cert_path,
1322 "mTLS client certificate configured"
1323 );
1324 }
1325 Err(e) => {
1326 error!(
1327 upstream_id = %self.id,
1328 target = %selection.address,
1329 error = %e,
1330 "Failed to load mTLS client certificate"
1331 );
1332 return Err(SentinelError::Tls {
1333 message: format!("Failed to load client certificate: {}", e),
1334 source: None,
1335 });
1336 }
1337 }
1338 }
1339 }
1340
1341 trace!(
1342 upstream_id = %self.id,
1343 target = %selection.address,
1344 alpn = ?peer.options.alpn,
1345 min_version = self.http_version.min_version,
1346 max_version = self.http_version.max_version,
1347 verify_cert = peer.options.verify_cert,
1348 verify_hostname = peer.options.verify_hostname,
1349 "Configured ALPN and TLS options for HTTP version negotiation"
1350 );
1351 }
1352
1353 if self.http_version.max_version >= 2 {
1355 if !self.http_version.h2_ping_interval.is_zero() {
1357 peer.options.h2_ping_interval = Some(self.http_version.h2_ping_interval);
1358 trace!(
1359 upstream_id = %self.id,
1360 target = %selection.address,
1361 h2_ping_interval_secs = self.http_version.h2_ping_interval.as_secs(),
1362 "Configured H2 ping interval"
1363 );
1364 }
1365 }
1366
1367 trace!(
1368 upstream_id = %self.id,
1369 target = %selection.address,
1370 tls = self.tls_enabled,
1371 sni = %sni_hostname,
1372 idle_timeout_secs = self.pool_config.idle_timeout.as_secs(),
1373 http_max_version = self.http_version.max_version,
1374 "Created peer with Pingora connection pooling enabled"
1375 );
1376
1377 Ok(peer)
1378 }
1379
1380 pub async fn report_result(&self, target: &str, success: bool) {
1382 trace!(
1383 upstream_id = %self.id,
1384 target = %target,
1385 success = success,
1386 "Reporting connection result"
1387 );
1388
1389 if success {
1390 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1391 breaker.record_success();
1392 trace!(
1393 upstream_id = %self.id,
1394 target = %target,
1395 "Recorded success in circuit breaker"
1396 );
1397 }
1398 self.load_balancer.report_health(target, true).await;
1399 } else {
1400 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1401 breaker.record_failure();
1402 debug!(
1403 upstream_id = %self.id,
1404 target = %target,
1405 "Recorded failure in circuit breaker"
1406 );
1407 }
1408 self.load_balancer.report_health(target, false).await;
1409 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1410 warn!(
1411 upstream_id = %self.id,
1412 target = %target,
1413 "Connection failure reported for target"
1414 );
1415 }
1416 }
1417
1418 pub async fn report_result_with_latency(
1424 &self,
1425 target: &str,
1426 success: bool,
1427 latency: Option<Duration>,
1428 ) {
1429 trace!(
1430 upstream_id = %self.id,
1431 target = %target,
1432 success = success,
1433 latency_ms = latency.map(|l| l.as_millis() as u64),
1434 "Reporting result with latency for adaptive LB"
1435 );
1436
1437 if success {
1439 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1440 breaker.record_success();
1441 }
1442 } else {
1443 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1444 breaker.record_failure();
1445 }
1446 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1447 }
1448
1449 self.load_balancer
1451 .report_result_with_latency(target, success, latency)
1452 .await;
1453 }
1454
1455 pub fn stats(&self) -> &PoolStats {
1457 &self.stats
1458 }
1459
1460 pub fn id(&self) -> &UpstreamId {
1462 &self.id
1463 }
1464
1465 pub fn target_count(&self) -> usize {
1467 self.targets.len()
1468 }
1469
1470 pub fn pool_config(&self) -> PoolConfigSnapshot {
1472 PoolConfigSnapshot {
1473 max_connections: self.pool_config.max_connections,
1474 max_idle: self.pool_config.max_idle,
1475 idle_timeout_secs: self.pool_config.idle_timeout.as_secs(),
1476 max_lifetime_secs: self.pool_config.max_lifetime.map(|d| d.as_secs()),
1477 connection_timeout_secs: self.pool_config.connection_timeout.as_secs(),
1478 read_timeout_secs: self.pool_config.read_timeout.as_secs(),
1479 write_timeout_secs: self.pool_config.write_timeout.as_secs(),
1480 }
1481 }
1482
1483 pub async fn has_healthy_targets(&self) -> bool {
1487 let healthy = self.load_balancer.healthy_targets().await;
1488 !healthy.is_empty()
1489 }
1490
1491 pub async fn select_shadow_target(
1496 &self,
1497 context: Option<&RequestContext>,
1498 ) -> SentinelResult<ShadowTarget> {
1499 let selection = self.load_balancer.select(context).await?;
1501
1502 let breakers = self.circuit_breakers.read().await;
1504 if let Some(breaker) = breakers.get(&selection.address) {
1505 if !breaker.is_closed() {
1506 return Err(SentinelError::upstream(
1507 self.id.to_string(),
1508 "Circuit breaker is open for shadow target",
1509 ));
1510 }
1511 }
1512
1513 let (host, port) = if selection.address.contains(':') {
1515 let parts: Vec<&str> = selection.address.rsplitn(2, ':').collect();
1516 if parts.len() == 2 {
1517 (
1518 parts[1].to_string(),
1519 parts[0].parse::<u16>().unwrap_or(if self.tls_enabled { 443 } else { 80 }),
1520 )
1521 } else {
1522 (selection.address.clone(), if self.tls_enabled { 443 } else { 80 })
1523 }
1524 } else {
1525 (selection.address.clone(), if self.tls_enabled { 443 } else { 80 })
1526 };
1527
1528 Ok(ShadowTarget {
1529 scheme: if self.tls_enabled { "https" } else { "http" }.to_string(),
1530 host,
1531 port,
1532 sni: self.tls_sni.clone(),
1533 })
1534 }
1535
1536 pub fn is_tls_enabled(&self) -> bool {
1538 self.tls_enabled
1539 }
1540
1541 pub async fn shutdown(&self) {
1545 info!(
1546 upstream_id = %self.id,
1547 target_count = self.targets.len(),
1548 total_requests = self.stats.requests.load(Ordering::Relaxed),
1549 total_successes = self.stats.successes.load(Ordering::Relaxed),
1550 total_failures = self.stats.failures.load(Ordering::Relaxed),
1551 "Shutting down upstream pool"
1552 );
1553 debug!(upstream_id = %self.id, "Upstream pool shutdown complete");
1555 }
1556}