1use async_trait::async_trait;
7use pingora::upstreams::peer::HttpPeer;
8use std::collections::HashMap;
9use std::net::ToSocketAddrs;
10use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, trace, warn};
15
16use sentinel_common::{
17 errors::{SentinelError, SentinelResult},
18 types::{CircuitBreakerConfig, LoadBalancingAlgorithm},
19 CircuitBreaker, UpstreamId,
20};
21use sentinel_config::UpstreamConfig;
22
/// One backend endpoint an upstream pool can send requests to.
///
/// The endpoint is identified by its host part (`address`) and `port`; the
/// `weight` is a relative share consumed by weight-aware balancers.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Host part of the endpoint (everything before the final `:`).
    pub address: String,
    /// TCP port of the endpoint.
    pub port: u16,
    /// Relative weight of this endpoint.
    pub weight: u32,
}
40
41impl UpstreamTarget {
42 pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
44 Self {
45 address: address.into(),
46 port,
47 weight,
48 }
49 }
50
51 pub fn from_address(addr: &str) -> Option<Self> {
53 let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
54 if parts.len() == 2 {
55 let port = parts[0].parse().ok()?;
56 let address = parts[1].to_string();
57 Some(Self {
58 address,
59 port,
60 weight: 100,
61 })
62 } else {
63 None
64 }
65 }
66
67 pub fn from_config(config: &sentinel_config::UpstreamTarget) -> Option<Self> {
69 Self::from_address(&config.address).map(|mut t| {
70 t.weight = config.weight;
71 t
72 })
73 }
74
75 pub fn full_address(&self) -> String {
77 format!("{}:{}", self.address, self.port)
78 }
79}
80
81pub mod adaptive;
87pub mod consistent_hash;
88pub mod health;
89pub mod inference_health;
90pub mod least_tokens;
91pub mod locality;
92pub mod maglev;
93pub mod p2c;
94pub mod peak_ewma;
95pub mod subset;
96pub mod weighted_least_conn;
97
98pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
100pub use consistent_hash::{ConsistentHashBalancer, ConsistentHashConfig};
101pub use health::{ActiveHealthChecker, HealthCheckRunner};
102pub use inference_health::InferenceHealthCheck;
103pub use least_tokens::{LeastTokensQueuedBalancer, LeastTokensQueuedConfig, LeastTokensQueuedTargetStats};
104pub use locality::{LocalityAwareBalancer, LocalityAwareConfig};
105pub use maglev::{MaglevBalancer, MaglevConfig};
106pub use p2c::{P2cBalancer, P2cConfig};
107pub use peak_ewma::{PeakEwmaBalancer, PeakEwmaConfig};
108pub use subset::{SubsetBalancer, SubsetConfig};
109pub use weighted_least_conn::{WeightedLeastConnBalancer, WeightedLeastConnConfig};
110
/// Request attributes made available to load balancers that route on them.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Socket address of the downstream client, if known.
    pub client_ip: Option<std::net::SocketAddr>,
    /// Request headers keyed by name.
    pub headers: HashMap<String, String>,
    /// Request path.
    pub path: String,
    /// Request method.
    pub method: String,
}
119
120#[async_trait]
122pub trait LoadBalancer: Send + Sync {
123 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection>;
125
126 async fn report_health(&self, address: &str, healthy: bool);
128
129 async fn healthy_targets(&self) -> Vec<String>;
131
132 async fn release(&self, _selection: &TargetSelection) {
134 }
136
137 async fn report_result(
139 &self,
140 _selection: &TargetSelection,
141 _success: bool,
142 _latency: Option<Duration>,
143 ) {
144 }
146
147 async fn report_result_with_latency(
154 &self,
155 address: &str,
156 success: bool,
157 _latency: Option<Duration>,
158 ) {
159 self.report_health(address, success).await;
161 }
162}
163
/// The result of a load-balancing decision.
#[derive(Debug, Clone)]
pub struct TargetSelection {
    /// Chosen target as `host:port`.
    pub address: String,
    /// Weight reported for the chosen target.
    pub weight: u32,
    /// Balancer-specific extra data; the simple balancers in this file leave it empty.
    pub metadata: HashMap<String, String>,
}
174
175pub struct UpstreamPool {
177 id: UpstreamId,
179 targets: Vec<UpstreamTarget>,
181 load_balancer: Arc<dyn LoadBalancer>,
183 pool_config: ConnectionPoolConfig,
185 http_version: HttpVersionOptions,
187 tls_enabled: bool,
189 tls_sni: Option<String>,
191 tls_config: Option<sentinel_config::UpstreamTlsConfig>,
193 circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
195 stats: Arc<PoolStats>,
197}
198
/// Connection sizing and I/O timeouts for one upstream.
///
/// The configuration's whole-second values are converted to [`Duration`]s.
/// Derives `Debug` and `Clone` so it can be logged and copied like the other
/// public config types in this module.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Configured maximum number of connections.
    pub max_connections: usize,
    /// Configured maximum number of idle connections.
    pub max_idle: usize,
    /// Idle timeout applied to created peers.
    pub idle_timeout: Duration,
    /// Optional maximum connection lifetime.
    pub max_lifetime: Option<Duration>,
    /// Connect timeout applied to created peers.
    pub connection_timeout: Duration,
    /// Read timeout applied to created peers.
    pub read_timeout: Duration,
    /// Write timeout applied to created peers.
    pub write_timeout: Duration,
}
223
/// HTTP version bounds and HTTP/2 tuning for an upstream.
///
/// Derives `Debug` and `Clone` for consistency with the other public config
/// types in this module.
#[derive(Debug, Clone)]
pub struct HttpVersionOptions {
    /// Minimum HTTP major version to use.
    pub min_version: u8,
    /// Maximum HTTP major version to use.
    pub max_version: u8,
    /// HTTP/2 ping interval; `Duration::ZERO` means no ping is configured.
    pub h2_ping_interval: Duration,
    /// Configured maximum number of concurrent HTTP/2 streams.
    pub max_h2_streams: usize,
}
235
236impl ConnectionPoolConfig {
237 pub fn from_config(
239 pool_config: &sentinel_config::ConnectionPoolConfig,
240 timeouts: &sentinel_config::UpstreamTimeouts,
241 ) -> Self {
242 Self {
243 max_connections: pool_config.max_connections,
244 max_idle: pool_config.max_idle,
245 idle_timeout: Duration::from_secs(pool_config.idle_timeout_secs),
246 max_lifetime: pool_config.max_lifetime_secs.map(Duration::from_secs),
247 connection_timeout: Duration::from_secs(timeouts.connect_secs),
248 read_timeout: Duration::from_secs(timeouts.read_secs),
249 write_timeout: Duration::from_secs(timeouts.write_secs),
250 }
251 }
252}
253
/// Lock-free counters for an upstream pool.
///
/// All counters start at zero. `Debug` is derived so a snapshot of the counters
/// can be logged directly; the atomics print their current values.
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Calls to peer selection.
    pub requests: AtomicU64,
    /// Peer selections that produced a peer.
    pub successes: AtomicU64,
    /// Failed selections plus reported request failures.
    pub failures: AtomicU64,
    /// Retry counter.
    pub retries: AtomicU64,
    /// Times a selected target was skipped because its circuit breaker was not closed.
    pub circuit_breaker_trips: AtomicU64,
}
270
/// Destination for a shadowed (mirrored) request.
#[derive(Debug, Clone)]
pub struct ShadowTarget {
    /// URL scheme, e.g. `http` or `https`.
    pub scheme: String,
    /// Host part of the target.
    pub host: String,
    /// Target port.
    pub port: u16,
    /// Optional SNI override.
    pub sni: Option<String>,
}

impl ShadowTarget {
    /// Builds `scheme://host[:port]path`.
    ///
    /// The port is left out only for `http` on 80 and `https` on 443.
    /// `path` is appended verbatim.
    pub fn build_url(&self, path: &str) -> String {
        let omit_port = matches!(
            (self.scheme.as_str(), self.port),
            ("http", 80) | ("https", 443)
        );
        if omit_port {
            format!("{}://{}{}", self.scheme, self.host, path)
        } else {
            format!("{}://{}:{}{}", self.scheme, self.host, self.port, path)
        }
    }
}
294
/// Plain-number copy of an upstream's pool settings, suitable for reporting.
///
/// Durations are expressed in whole seconds.
#[derive(Debug, Clone)]
pub struct PoolConfigSnapshot {
    /// Configured maximum number of connections.
    pub max_connections: usize,
    /// Configured maximum number of idle connections.
    pub max_idle: usize,
    /// Idle timeout, in seconds.
    pub idle_timeout_secs: u64,
    /// Maximum connection lifetime, in seconds, if set.
    pub max_lifetime_secs: Option<u64>,
    /// Connect timeout, in seconds.
    pub connection_timeout_secs: u64,
    /// Read timeout, in seconds.
    pub read_timeout_secs: u64,
    /// Write timeout, in seconds.
    pub write_timeout_secs: u64,
}
313
314struct RoundRobinBalancer {
316 targets: Vec<UpstreamTarget>,
317 current: AtomicUsize,
318 health_status: Arc<RwLock<HashMap<String, bool>>>,
319}
320
321impl RoundRobinBalancer {
322 fn new(targets: Vec<UpstreamTarget>) -> Self {
323 let mut health_status = HashMap::new();
324 for target in &targets {
325 health_status.insert(target.full_address(), true);
326 }
327
328 Self {
329 targets,
330 current: AtomicUsize::new(0),
331 health_status: Arc::new(RwLock::new(health_status)),
332 }
333 }
334}
335
336#[async_trait]
337impl LoadBalancer for RoundRobinBalancer {
338 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
339 trace!(
340 total_targets = self.targets.len(),
341 algorithm = "round_robin",
342 "Selecting upstream target"
343 );
344
345 let health = self.health_status.read().await;
346 let healthy_targets: Vec<_> = self
347 .targets
348 .iter()
349 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
350 .collect();
351
352 if healthy_targets.is_empty() {
353 warn!(
354 total_targets = self.targets.len(),
355 algorithm = "round_robin",
356 "No healthy upstream targets available"
357 );
358 return Err(SentinelError::NoHealthyUpstream);
359 }
360
361 let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
362 let target = healthy_targets[index];
363
364 trace!(
365 selected_target = %target.full_address(),
366 healthy_count = healthy_targets.len(),
367 index = index,
368 algorithm = "round_robin",
369 "Selected target via round robin"
370 );
371
372 Ok(TargetSelection {
373 address: target.full_address(),
374 weight: target.weight,
375 metadata: HashMap::new(),
376 })
377 }
378
379 async fn report_health(&self, address: &str, healthy: bool) {
380 trace!(
381 target = %address,
382 healthy = healthy,
383 algorithm = "round_robin",
384 "Updating target health status"
385 );
386 self.health_status
387 .write()
388 .await
389 .insert(address.to_string(), healthy);
390 }
391
392 async fn healthy_targets(&self) -> Vec<String> {
393 self.health_status
394 .read()
395 .await
396 .iter()
397 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
398 .collect()
399 }
400}
401
402struct RandomBalancer {
404 targets: Vec<UpstreamTarget>,
405 health_status: Arc<RwLock<HashMap<String, bool>>>,
406}
407
408impl RandomBalancer {
409 fn new(targets: Vec<UpstreamTarget>) -> Self {
410 let mut health_status = HashMap::new();
411 for target in &targets {
412 health_status.insert(target.full_address(), true);
413 }
414
415 Self {
416 targets,
417 health_status: Arc::new(RwLock::new(health_status)),
418 }
419 }
420}
421
422#[async_trait]
423impl LoadBalancer for RandomBalancer {
424 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
425 use rand::seq::SliceRandom;
426
427 trace!(
428 total_targets = self.targets.len(),
429 algorithm = "random",
430 "Selecting upstream target"
431 );
432
433 let health = self.health_status.read().await;
434 let healthy_targets: Vec<_> = self
435 .targets
436 .iter()
437 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
438 .collect();
439
440 if healthy_targets.is_empty() {
441 warn!(
442 total_targets = self.targets.len(),
443 algorithm = "random",
444 "No healthy upstream targets available"
445 );
446 return Err(SentinelError::NoHealthyUpstream);
447 }
448
449 let mut rng = rand::thread_rng();
450 let target = healthy_targets
451 .choose(&mut rng)
452 .ok_or(SentinelError::NoHealthyUpstream)?;
453
454 trace!(
455 selected_target = %target.full_address(),
456 healthy_count = healthy_targets.len(),
457 algorithm = "random",
458 "Selected target via random selection"
459 );
460
461 Ok(TargetSelection {
462 address: target.full_address(),
463 weight: target.weight,
464 metadata: HashMap::new(),
465 })
466 }
467
468 async fn report_health(&self, address: &str, healthy: bool) {
469 trace!(
470 target = %address,
471 healthy = healthy,
472 algorithm = "random",
473 "Updating target health status"
474 );
475 self.health_status
476 .write()
477 .await
478 .insert(address.to_string(), healthy);
479 }
480
481 async fn healthy_targets(&self) -> Vec<String> {
482 self.health_status
483 .read()
484 .await
485 .iter()
486 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
487 .collect()
488 }
489}
490
491struct LeastConnectionsBalancer {
493 targets: Vec<UpstreamTarget>,
494 connections: Arc<RwLock<HashMap<String, usize>>>,
495 health_status: Arc<RwLock<HashMap<String, bool>>>,
496}
497
498impl LeastConnectionsBalancer {
499 fn new(targets: Vec<UpstreamTarget>) -> Self {
500 let mut health_status = HashMap::new();
501 let mut connections = HashMap::new();
502
503 for target in &targets {
504 let addr = target.full_address();
505 health_status.insert(addr.clone(), true);
506 connections.insert(addr, 0);
507 }
508
509 Self {
510 targets,
511 connections: Arc::new(RwLock::new(connections)),
512 health_status: Arc::new(RwLock::new(health_status)),
513 }
514 }
515}
516
517#[async_trait]
518impl LoadBalancer for LeastConnectionsBalancer {
519 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
520 trace!(
521 total_targets = self.targets.len(),
522 algorithm = "least_connections",
523 "Selecting upstream target"
524 );
525
526 let health = self.health_status.read().await;
527 let conns = self.connections.read().await;
528
529 let mut best_target = None;
530 let mut min_connections = usize::MAX;
531
532 for target in &self.targets {
533 let addr = target.full_address();
534 if !*health.get(&addr).unwrap_or(&true) {
535 trace!(
536 target = %addr,
537 algorithm = "least_connections",
538 "Skipping unhealthy target"
539 );
540 continue;
541 }
542
543 let conn_count = *conns.get(&addr).unwrap_or(&0);
544 trace!(
545 target = %addr,
546 connections = conn_count,
547 "Evaluating target connection count"
548 );
549 if conn_count < min_connections {
550 min_connections = conn_count;
551 best_target = Some(target);
552 }
553 }
554
555 match best_target {
556 Some(target) => {
557 trace!(
558 selected_target = %target.full_address(),
559 connections = min_connections,
560 algorithm = "least_connections",
561 "Selected target with fewest connections"
562 );
563 Ok(TargetSelection {
564 address: target.full_address(),
565 weight: target.weight,
566 metadata: HashMap::new(),
567 })
568 }
569 None => {
570 warn!(
571 total_targets = self.targets.len(),
572 algorithm = "least_connections",
573 "No healthy upstream targets available"
574 );
575 Err(SentinelError::NoHealthyUpstream)
576 }
577 }
578 }
579
580 async fn report_health(&self, address: &str, healthy: bool) {
581 trace!(
582 target = %address,
583 healthy = healthy,
584 algorithm = "least_connections",
585 "Updating target health status"
586 );
587 self.health_status
588 .write()
589 .await
590 .insert(address.to_string(), healthy);
591 }
592
593 async fn healthy_targets(&self) -> Vec<String> {
594 self.health_status
595 .read()
596 .await
597 .iter()
598 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
599 .collect()
600 }
601}
602
603struct WeightedBalancer {
605 targets: Vec<UpstreamTarget>,
606 weights: Vec<u32>,
607 current_index: AtomicUsize,
608 health_status: Arc<RwLock<HashMap<String, bool>>>,
609}
610
611#[async_trait]
612impl LoadBalancer for WeightedBalancer {
613 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
614 trace!(
615 total_targets = self.targets.len(),
616 algorithm = "weighted",
617 "Selecting upstream target"
618 );
619
620 let health = self.health_status.read().await;
621 let healthy_indices: Vec<_> = self
622 .targets
623 .iter()
624 .enumerate()
625 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
626 .map(|(i, _)| i)
627 .collect();
628
629 if healthy_indices.is_empty() {
630 warn!(
631 total_targets = self.targets.len(),
632 algorithm = "weighted",
633 "No healthy upstream targets available"
634 );
635 return Err(SentinelError::NoHealthyUpstream);
636 }
637
638 let idx = self.current_index.fetch_add(1, Ordering::Relaxed) % healthy_indices.len();
639 let target_idx = healthy_indices[idx];
640 let target = &self.targets[target_idx];
641 let weight = self.weights.get(target_idx).copied().unwrap_or(1);
642
643 trace!(
644 selected_target = %target.full_address(),
645 weight = weight,
646 healthy_count = healthy_indices.len(),
647 algorithm = "weighted",
648 "Selected target via weighted round robin"
649 );
650
651 Ok(TargetSelection {
652 address: target.full_address(),
653 weight,
654 metadata: HashMap::new(),
655 })
656 }
657
658 async fn report_health(&self, address: &str, healthy: bool) {
659 trace!(
660 target = %address,
661 healthy = healthy,
662 algorithm = "weighted",
663 "Updating target health status"
664 );
665 self.health_status
666 .write()
667 .await
668 .insert(address.to_string(), healthy);
669 }
670
671 async fn healthy_targets(&self) -> Vec<String> {
672 self.health_status
673 .read()
674 .await
675 .iter()
676 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
677 .collect()
678 }
679}
680
681struct IpHashBalancer {
683 targets: Vec<UpstreamTarget>,
684 health_status: Arc<RwLock<HashMap<String, bool>>>,
685}
686
687#[async_trait]
688impl LoadBalancer for IpHashBalancer {
689 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
690 trace!(
691 total_targets = self.targets.len(),
692 algorithm = "ip_hash",
693 "Selecting upstream target"
694 );
695
696 let health = self.health_status.read().await;
697 let healthy_targets: Vec<_> = self
698 .targets
699 .iter()
700 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
701 .collect();
702
703 if healthy_targets.is_empty() {
704 warn!(
705 total_targets = self.targets.len(),
706 algorithm = "ip_hash",
707 "No healthy upstream targets available"
708 );
709 return Err(SentinelError::NoHealthyUpstream);
710 }
711
712 let (hash, client_ip_str) = if let Some(ctx) = context {
714 if let Some(ip) = &ctx.client_ip {
715 use std::hash::{Hash, Hasher};
716 let mut hasher = std::collections::hash_map::DefaultHasher::new();
717 ip.hash(&mut hasher);
718 (hasher.finish(), Some(ip.to_string()))
719 } else {
720 (0, None)
721 }
722 } else {
723 (0, None)
724 };
725
726 let idx = (hash as usize) % healthy_targets.len();
727 let target = healthy_targets[idx];
728
729 trace!(
730 selected_target = %target.full_address(),
731 client_ip = client_ip_str.as_deref().unwrap_or("unknown"),
732 hash = hash,
733 index = idx,
734 healthy_count = healthy_targets.len(),
735 algorithm = "ip_hash",
736 "Selected target via IP hash"
737 );
738
739 Ok(TargetSelection {
740 address: target.full_address(),
741 weight: target.weight,
742 metadata: HashMap::new(),
743 })
744 }
745
746 async fn report_health(&self, address: &str, healthy: bool) {
747 trace!(
748 target = %address,
749 healthy = healthy,
750 algorithm = "ip_hash",
751 "Updating target health status"
752 );
753 self.health_status
754 .write()
755 .await
756 .insert(address.to_string(), healthy);
757 }
758
759 async fn healthy_targets(&self) -> Vec<String> {
760 self.health_status
761 .read()
762 .await
763 .iter()
764 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
765 .collect()
766 }
767}
768
769impl UpstreamPool {
770 pub async fn new(config: UpstreamConfig) -> SentinelResult<Self> {
772 let id = UpstreamId::new(&config.id);
773
774 info!(
775 upstream_id = %config.id,
776 target_count = config.targets.len(),
777 algorithm = ?config.load_balancing,
778 "Creating upstream pool"
779 );
780
781 let targets: Vec<UpstreamTarget> = config
783 .targets
784 .iter()
785 .filter_map(UpstreamTarget::from_config)
786 .collect();
787
788 if targets.is_empty() {
789 error!(
790 upstream_id = %config.id,
791 "No valid upstream targets configured"
792 );
793 return Err(SentinelError::Config {
794 message: "No valid upstream targets".to_string(),
795 source: None,
796 });
797 }
798
799 for target in &targets {
800 debug!(
801 upstream_id = %config.id,
802 target = %target.full_address(),
803 weight = target.weight,
804 "Registered upstream target"
805 );
806 }
807
808 debug!(
810 upstream_id = %config.id,
811 algorithm = ?config.load_balancing,
812 "Creating load balancer"
813 );
814 let load_balancer = Self::create_load_balancer(&config.load_balancing, &targets)?;
815
816 debug!(
818 upstream_id = %config.id,
819 max_connections = config.connection_pool.max_connections,
820 max_idle = config.connection_pool.max_idle,
821 idle_timeout_secs = config.connection_pool.idle_timeout_secs,
822 connect_timeout_secs = config.timeouts.connect_secs,
823 read_timeout_secs = config.timeouts.read_secs,
824 write_timeout_secs = config.timeouts.write_secs,
825 "Creating connection pool configuration"
826 );
827 let pool_config =
828 ConnectionPoolConfig::from_config(&config.connection_pool, &config.timeouts);
829
830 let http_version = HttpVersionOptions {
832 min_version: config.http_version.min_version,
833 max_version: config.http_version.max_version,
834 h2_ping_interval: if config.http_version.h2_ping_interval_secs > 0 {
835 Duration::from_secs(config.http_version.h2_ping_interval_secs)
836 } else {
837 Duration::ZERO
838 },
839 max_h2_streams: config.http_version.max_h2_streams,
840 };
841
842 let tls_enabled = config.tls.is_some();
844 let tls_sni = config.tls.as_ref().and_then(|t| t.sni.clone());
845 let tls_config = config.tls.clone();
846
847 if let Some(ref tls) = tls_config {
849 if tls.client_cert.is_some() {
850 info!(
851 upstream_id = %config.id,
852 "mTLS enabled for upstream (client certificate configured)"
853 );
854 }
855 }
856
857 if http_version.max_version >= 2 && tls_enabled {
858 info!(
859 upstream_id = %config.id,
860 "HTTP/2 enabled for upstream (via ALPN)"
861 );
862 }
863
864 let mut circuit_breakers = HashMap::new();
866 for target in &targets {
867 trace!(
868 upstream_id = %config.id,
869 target = %target.full_address(),
870 "Initializing circuit breaker for target"
871 );
872 circuit_breakers.insert(
873 target.full_address(),
874 CircuitBreaker::new(CircuitBreakerConfig::default()),
875 );
876 }
877
878 let pool = Self {
879 id: id.clone(),
880 targets,
881 load_balancer,
882 pool_config,
883 http_version,
884 tls_enabled,
885 tls_sni,
886 tls_config,
887 circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
888 stats: Arc::new(PoolStats::default()),
889 };
890
891 info!(
892 upstream_id = %id,
893 target_count = pool.targets.len(),
894 "Upstream pool created successfully"
895 );
896
897 Ok(pool)
898 }
899
900 fn create_load_balancer(
902 algorithm: &LoadBalancingAlgorithm,
903 targets: &[UpstreamTarget],
904 ) -> SentinelResult<Arc<dyn LoadBalancer>> {
905 let balancer: Arc<dyn LoadBalancer> = match algorithm {
906 LoadBalancingAlgorithm::RoundRobin => {
907 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
908 }
909 LoadBalancingAlgorithm::LeastConnections => {
910 Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
911 }
912 LoadBalancingAlgorithm::Weighted => {
913 let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
914 Arc::new(WeightedBalancer {
915 targets: targets.to_vec(),
916 weights,
917 current_index: AtomicUsize::new(0),
918 health_status: Arc::new(RwLock::new(HashMap::new())),
919 })
920 }
921 LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
922 targets: targets.to_vec(),
923 health_status: Arc::new(RwLock::new(HashMap::new())),
924 }),
925 LoadBalancingAlgorithm::Random => {
926 Arc::new(RandomBalancer::new(targets.to_vec()))
927 }
928 LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
929 targets.to_vec(),
930 ConsistentHashConfig::default(),
931 )),
932 LoadBalancingAlgorithm::PowerOfTwoChoices => {
933 Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
934 }
935 LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
936 targets.to_vec(),
937 AdaptiveConfig::default(),
938 )),
939 LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
940 targets.to_vec(),
941 LeastTokensQueuedConfig::default(),
942 )),
943 LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
944 targets.to_vec(),
945 MaglevConfig::default(),
946 )),
947 LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
948 targets.to_vec(),
949 LocalityAwareConfig::default(),
950 )),
951 LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
952 targets.to_vec(),
953 PeakEwmaConfig::default(),
954 )),
955 LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
956 targets.to_vec(),
957 SubsetConfig::default(),
958 )),
959 LoadBalancingAlgorithm::WeightedLeastConnections => {
960 Arc::new(WeightedLeastConnBalancer::new(
961 targets.to_vec(),
962 WeightedLeastConnConfig::default(),
963 ))
964 }
965 };
966 Ok(balancer)
967 }
968
969 pub async fn select_peer(&self, context: Option<&RequestContext>) -> SentinelResult<HttpPeer> {
971 let request_num = self.stats.requests.fetch_add(1, Ordering::Relaxed) + 1;
972
973 trace!(
974 upstream_id = %self.id,
975 request_num = request_num,
976 target_count = self.targets.len(),
977 "Starting peer selection"
978 );
979
980 let mut attempts = 0;
981 let max_attempts = self.targets.len() * 2;
982
983 while attempts < max_attempts {
984 attempts += 1;
985
986 trace!(
987 upstream_id = %self.id,
988 attempt = attempts,
989 max_attempts = max_attempts,
990 "Attempting to select peer"
991 );
992
993 let selection = match self.load_balancer.select(context).await {
994 Ok(s) => s,
995 Err(e) => {
996 warn!(
997 upstream_id = %self.id,
998 attempt = attempts,
999 error = %e,
1000 "Load balancer selection failed"
1001 );
1002 continue;
1003 }
1004 };
1005
1006 trace!(
1007 upstream_id = %self.id,
1008 target = %selection.address,
1009 attempt = attempts,
1010 "Load balancer selected target"
1011 );
1012
1013 let breakers = self.circuit_breakers.read().await;
1015 if let Some(breaker) = breakers.get(&selection.address) {
1016 if !breaker.is_closed().await {
1017 debug!(
1018 upstream_id = %self.id,
1019 target = %selection.address,
1020 attempt = attempts,
1021 "Circuit breaker is open, skipping target"
1022 );
1023 self.stats
1024 .circuit_breaker_trips
1025 .fetch_add(1, Ordering::Relaxed);
1026 continue;
1027 }
1028 }
1029
1030 trace!(
1034 upstream_id = %self.id,
1035 target = %selection.address,
1036 "Creating peer for upstream (Pingora handles connection reuse)"
1037 );
1038 let peer = self.create_peer(&selection)?;
1039
1040 debug!(
1041 upstream_id = %self.id,
1042 target = %selection.address,
1043 attempt = attempts,
1044 "Selected upstream peer"
1045 );
1046
1047 self.stats.successes.fetch_add(1, Ordering::Relaxed);
1048 return Ok(peer);
1049 }
1050
1051 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1052 error!(
1053 upstream_id = %self.id,
1054 attempts = attempts,
1055 max_attempts = max_attempts,
1056 "Failed to select upstream after max attempts"
1057 );
1058 Err(SentinelError::upstream(
1059 self.id.to_string(),
1060 "Failed to select upstream after max attempts",
1061 ))
1062 }
1063
1064 fn create_peer(&self, selection: &TargetSelection) -> SentinelResult<HttpPeer> {
1070 let sni_hostname = self.tls_sni.clone().unwrap_or_else(|| {
1072 selection
1074 .address
1075 .split(':')
1076 .next()
1077 .unwrap_or(&selection.address)
1078 .to_string()
1079 });
1080
1081 let resolved_address = selection
1084 .address
1085 .to_socket_addrs()
1086 .map_err(|e| {
1087 error!(
1088 upstream = %self.id,
1089 address = %selection.address,
1090 error = %e,
1091 "Failed to resolve upstream address"
1092 );
1093 SentinelError::Upstream {
1094 upstream: self.id.to_string(),
1095 message: format!("DNS resolution failed for {}: {}", selection.address, e),
1096 retryable: true,
1097 source: None,
1098 }
1099 })?
1100 .next()
1101 .ok_or_else(|| {
1102 error!(
1103 upstream = %self.id,
1104 address = %selection.address,
1105 "No addresses returned from DNS resolution"
1106 );
1107 SentinelError::Upstream {
1108 upstream: self.id.to_string(),
1109 message: format!("No addresses for {}", selection.address),
1110 retryable: true,
1111 source: None,
1112 }
1113 })?;
1114
1115 let mut peer = HttpPeer::new(resolved_address, self.tls_enabled, sni_hostname.clone());
1117
1118 peer.options.idle_timeout = Some(self.pool_config.idle_timeout);
1122
1123 peer.options.connection_timeout = Some(self.pool_config.connection_timeout);
1125 peer.options.total_connection_timeout = Some(Duration::from_secs(10));
1126
1127 peer.options.read_timeout = Some(self.pool_config.read_timeout);
1129 peer.options.write_timeout = Some(self.pool_config.write_timeout);
1130
1131 peer.options.tcp_keepalive = Some(pingora::protocols::TcpKeepalive {
1133 idle: Duration::from_secs(60),
1134 interval: Duration::from_secs(10),
1135 count: 3,
1136 #[cfg(target_os = "linux")]
1138 user_timeout: Duration::from_secs(60),
1139 });
1140
1141 if self.tls_enabled {
1143 let alpn = match (self.http_version.min_version, self.http_version.max_version) {
1145 (2, _) => {
1146 pingora::upstreams::peer::ALPN::H2
1148 }
1149 (1, 2) | (_, 2) => {
1150 pingora::upstreams::peer::ALPN::H2H1
1152 }
1153 _ => {
1154 pingora::upstreams::peer::ALPN::H1
1156 }
1157 };
1158 peer.options.alpn = alpn;
1159
1160 if let Some(ref tls_config) = self.tls_config {
1162 if tls_config.insecure_skip_verify {
1164 peer.options.verify_cert = false;
1165 peer.options.verify_hostname = false;
1166 warn!(
1167 upstream_id = %self.id,
1168 target = %selection.address,
1169 "TLS certificate verification DISABLED (insecure_skip_verify=true)"
1170 );
1171 }
1172
1173 if let Some(ref sni) = tls_config.sni {
1175 peer.options.alternative_cn = Some(sni.clone());
1176 trace!(
1177 upstream_id = %self.id,
1178 target = %selection.address,
1179 alternative_cn = %sni,
1180 "Set alternative CN for TLS verification"
1181 );
1182 }
1183
1184 if let (Some(cert_path), Some(key_path)) =
1186 (&tls_config.client_cert, &tls_config.client_key)
1187 {
1188 match crate::tls::load_client_cert_key(cert_path, key_path) {
1189 Ok(cert_key) => {
1190 peer.client_cert_key = Some(cert_key);
1191 info!(
1192 upstream_id = %self.id,
1193 target = %selection.address,
1194 cert_path = ?cert_path,
1195 "mTLS client certificate configured"
1196 );
1197 }
1198 Err(e) => {
1199 error!(
1200 upstream_id = %self.id,
1201 target = %selection.address,
1202 error = %e,
1203 "Failed to load mTLS client certificate"
1204 );
1205 return Err(SentinelError::Tls {
1206 message: format!("Failed to load client certificate: {}", e),
1207 source: None,
1208 });
1209 }
1210 }
1211 }
1212 }
1213
1214 trace!(
1215 upstream_id = %self.id,
1216 target = %selection.address,
1217 alpn = ?peer.options.alpn,
1218 min_version = self.http_version.min_version,
1219 max_version = self.http_version.max_version,
1220 verify_cert = peer.options.verify_cert,
1221 verify_hostname = peer.options.verify_hostname,
1222 "Configured ALPN and TLS options for HTTP version negotiation"
1223 );
1224 }
1225
1226 if self.http_version.max_version >= 2 {
1228 if !self.http_version.h2_ping_interval.is_zero() {
1230 peer.options.h2_ping_interval = Some(self.http_version.h2_ping_interval);
1231 trace!(
1232 upstream_id = %self.id,
1233 target = %selection.address,
1234 h2_ping_interval_secs = self.http_version.h2_ping_interval.as_secs(),
1235 "Configured H2 ping interval"
1236 );
1237 }
1238 }
1239
1240 trace!(
1241 upstream_id = %self.id,
1242 target = %selection.address,
1243 tls = self.tls_enabled,
1244 sni = %sni_hostname,
1245 idle_timeout_secs = self.pool_config.idle_timeout.as_secs(),
1246 http_max_version = self.http_version.max_version,
1247 "Created peer with Pingora connection pooling enabled"
1248 );
1249
1250 Ok(peer)
1251 }
1252
1253 pub async fn report_result(&self, target: &str, success: bool) {
1255 trace!(
1256 upstream_id = %self.id,
1257 target = %target,
1258 success = success,
1259 "Reporting connection result"
1260 );
1261
1262 if success {
1263 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1264 breaker.record_success().await;
1265 trace!(
1266 upstream_id = %self.id,
1267 target = %target,
1268 "Recorded success in circuit breaker"
1269 );
1270 }
1271 self.load_balancer.report_health(target, true).await;
1272 } else {
1273 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1274 breaker.record_failure().await;
1275 debug!(
1276 upstream_id = %self.id,
1277 target = %target,
1278 "Recorded failure in circuit breaker"
1279 );
1280 }
1281 self.load_balancer.report_health(target, false).await;
1282 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1283 warn!(
1284 upstream_id = %self.id,
1285 target = %target,
1286 "Connection failure reported for target"
1287 );
1288 }
1289 }
1290
1291 pub async fn report_result_with_latency(
1297 &self,
1298 target: &str,
1299 success: bool,
1300 latency: Option<Duration>,
1301 ) {
1302 trace!(
1303 upstream_id = %self.id,
1304 target = %target,
1305 success = success,
1306 latency_ms = latency.map(|l| l.as_millis() as u64),
1307 "Reporting result with latency for adaptive LB"
1308 );
1309
1310 if success {
1312 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1313 breaker.record_success().await;
1314 }
1315 } else {
1316 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1317 breaker.record_failure().await;
1318 }
1319 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1320 }
1321
1322 self.load_balancer
1324 .report_result_with_latency(target, success, latency)
1325 .await;
1326 }
1327
1328 pub fn stats(&self) -> &PoolStats {
1330 &self.stats
1331 }
1332
1333 pub fn id(&self) -> &UpstreamId {
1335 &self.id
1336 }
1337
1338 pub fn target_count(&self) -> usize {
1340 self.targets.len()
1341 }
1342
1343 pub fn pool_config(&self) -> PoolConfigSnapshot {
1345 PoolConfigSnapshot {
1346 max_connections: self.pool_config.max_connections,
1347 max_idle: self.pool_config.max_idle,
1348 idle_timeout_secs: self.pool_config.idle_timeout.as_secs(),
1349 max_lifetime_secs: self.pool_config.max_lifetime.map(|d| d.as_secs()),
1350 connection_timeout_secs: self.pool_config.connection_timeout.as_secs(),
1351 read_timeout_secs: self.pool_config.read_timeout.as_secs(),
1352 write_timeout_secs: self.pool_config.write_timeout.as_secs(),
1353 }
1354 }
1355
1356 pub async fn has_healthy_targets(&self) -> bool {
1360 let healthy = self.load_balancer.healthy_targets().await;
1361 !healthy.is_empty()
1362 }
1363
1364 pub async fn select_shadow_target(
1369 &self,
1370 context: Option<&RequestContext>,
1371 ) -> SentinelResult<ShadowTarget> {
1372 let selection = self.load_balancer.select(context).await?;
1374
1375 let breakers = self.circuit_breakers.read().await;
1377 if let Some(breaker) = breakers.get(&selection.address) {
1378 if !breaker.is_closed().await {
1379 return Err(SentinelError::upstream(
1380 self.id.to_string(),
1381 "Circuit breaker is open for shadow target",
1382 ));
1383 }
1384 }
1385
1386 let (host, port) = if selection.address.contains(':') {
1388 let parts: Vec<&str> = selection.address.rsplitn(2, ':').collect();
1389 if parts.len() == 2 {
1390 (
1391 parts[1].to_string(),
1392 parts[0].parse::<u16>().unwrap_or(if self.tls_enabled { 443 } else { 80 }),
1393 )
1394 } else {
1395 (selection.address.clone(), if self.tls_enabled { 443 } else { 80 })
1396 }
1397 } else {
1398 (selection.address.clone(), if self.tls_enabled { 443 } else { 80 })
1399 };
1400
1401 Ok(ShadowTarget {
1402 scheme: if self.tls_enabled { "https" } else { "http" }.to_string(),
1403 host,
1404 port,
1405 sni: self.tls_sni.clone(),
1406 })
1407 }
1408
1409 pub fn is_tls_enabled(&self) -> bool {
1411 self.tls_enabled
1412 }
1413
1414 pub async fn shutdown(&self) {
1418 info!(
1419 upstream_id = %self.id,
1420 target_count = self.targets.len(),
1421 total_requests = self.stats.requests.load(Ordering::Relaxed),
1422 total_successes = self.stats.successes.load(Ordering::Relaxed),
1423 total_failures = self.stats.failures.load(Ordering::Relaxed),
1424 "Shutting down upstream pool"
1425 );
1426 debug!(upstream_id = %self.id, "Upstream pool shutdown complete");
1428 }
1429}