1use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, trace, warn};
20
21use sentinel_common::errors::{SentinelError, SentinelResult};
22
23use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
24
25#[derive(Debug, Clone)]
27pub struct WeightedLeastConnConfig {
28 pub min_weight: u32,
30 pub tie_breaker: TieBreakerStrategy,
32}
33
34impl Default for WeightedLeastConnConfig {
35 fn default() -> Self {
36 Self {
37 min_weight: 1,
38 tie_breaker: TieBreakerStrategy::HigherWeight,
39 }
40 }
41}
42
/// How the balancer chooses between targets whose scores are tied.
///
/// The type is a plain field-less enum, so it also derives `PartialEq`/`Eq`
/// to let callers and tests compare strategies directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TieBreakerStrategy {
    /// Pick the tied candidate with the greatest weight (the default).
    #[default]
    HigherWeight,

    /// Pick the tied candidate with the fewest active connections.
    FewerConnections,

    /// Rotate through the tied candidates using a shared counter.
    RoundRobin,
}
54
55pub struct WeightedLeastConnBalancer {
57 targets: Vec<UpstreamTarget>,
59 connections: Arc<RwLock<HashMap<String, usize>>>,
61 health_status: Arc<RwLock<HashMap<String, bool>>>,
63 tie_breaker_counter: AtomicUsize,
65 config: WeightedLeastConnConfig,
67}
68
69impl WeightedLeastConnBalancer {
70 pub fn new(targets: Vec<UpstreamTarget>, config: WeightedLeastConnConfig) -> Self {
72 let mut health_status = HashMap::new();
73 let mut connections = HashMap::new();
74
75 for target in &targets {
76 let addr = target.full_address();
77 health_status.insert(addr.clone(), true);
78 connections.insert(addr, 0);
79 }
80
81 Self {
82 targets,
83 connections: Arc::new(RwLock::new(connections)),
84 health_status: Arc::new(RwLock::new(health_status)),
85 tie_breaker_counter: AtomicUsize::new(0),
86 config,
87 }
88 }
89
90 fn calculate_score(&self, connections: usize, weight: u32) -> f64 {
93 let effective_weight = weight.max(self.config.min_weight) as f64;
94 connections as f64 / effective_weight
95 }
96
97 fn break_tie<'a>(
99 &self,
100 candidates: &[(&'a UpstreamTarget, usize)],
101 ) -> Option<&'a UpstreamTarget> {
102 if candidates.is_empty() {
103 return None;
104 }
105 if candidates.len() == 1 {
106 return Some(candidates[0].0);
107 }
108
109 match self.config.tie_breaker {
110 TieBreakerStrategy::HigherWeight => {
111 candidates.iter().max_by_key(|(t, _)| t.weight).map(|(t, _)| *t)
112 }
113 TieBreakerStrategy::FewerConnections => {
114 candidates.iter().min_by_key(|(_, c)| *c).map(|(t, _)| *t)
115 }
116 TieBreakerStrategy::RoundRobin => {
117 let idx = self.tie_breaker_counter.fetch_add(1, Ordering::Relaxed) % candidates.len();
118 Some(candidates[idx].0)
119 }
120 }
121 }
122}
123
124#[async_trait]
125impl LoadBalancer for WeightedLeastConnBalancer {
126 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
127 trace!(
128 total_targets = self.targets.len(),
129 algorithm = "weighted_least_conn",
130 "Selecting upstream target"
131 );
132
133 let health = self.health_status.read().await;
134 let conns = self.connections.read().await;
135
136 let scored_targets: Vec<_> = self
138 .targets
139 .iter()
140 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
141 .map(|t| {
142 let addr = t.full_address();
143 let conn_count = *conns.get(&addr).unwrap_or(&0);
144 let score = self.calculate_score(conn_count, t.weight);
145 (t, conn_count, score)
146 })
147 .collect();
148
149 drop(health);
150
151 if scored_targets.is_empty() {
152 warn!(
153 total_targets = self.targets.len(),
154 algorithm = "weighted_least_conn",
155 "No healthy upstream targets available"
156 );
157 return Err(SentinelError::NoHealthyUpstream);
158 }
159
160 let min_score = scored_targets
162 .iter()
163 .map(|(_, _, s)| *s)
164 .fold(f64::INFINITY, f64::min);
165
166 let candidates: Vec<_> = scored_targets
168 .iter()
169 .filter(|(_, _, s)| (*s - min_score).abs() < f64::EPSILON)
170 .map(|(t, c, _)| (*t, *c))
171 .collect();
172
173 let target = self.break_tie(&candidates).ok_or(SentinelError::NoHealthyUpstream)?;
174
175 drop(conns);
177 {
178 let mut conns = self.connections.write().await;
179 *conns.entry(target.full_address()).or_insert(0) += 1;
180 }
181
182 let conn_count = *self.connections.read().await.get(&target.full_address()).unwrap_or(&0);
183 let score = self.calculate_score(conn_count, target.weight);
184
185 trace!(
186 selected_target = %target.full_address(),
187 weight = target.weight,
188 connections = conn_count,
189 score = score,
190 healthy_count = scored_targets.len(),
191 algorithm = "weighted_least_conn",
192 "Selected target via weighted least connections"
193 );
194
195 Ok(TargetSelection {
196 address: target.full_address(),
197 weight: target.weight,
198 metadata: HashMap::new(),
199 })
200 }
201
202 async fn release(&self, selection: &TargetSelection) {
203 let mut conns = self.connections.write().await;
204 if let Some(count) = conns.get_mut(&selection.address) {
205 *count = count.saturating_sub(1);
206 trace!(
207 target = %selection.address,
208 connections = *count,
209 algorithm = "weighted_least_conn",
210 "Released connection"
211 );
212 }
213 }
214
215 async fn report_health(&self, address: &str, healthy: bool) {
216 trace!(
217 target = %address,
218 healthy = healthy,
219 algorithm = "weighted_least_conn",
220 "Updating target health status"
221 );
222 self.health_status
223 .write()
224 .await
225 .insert(address.to_string(), healthy);
226 }
227
228 async fn healthy_targets(&self) -> Vec<String> {
229 self.health_status
230 .read()
231 .await
232 .iter()
233 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
234 .collect()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three targets on port 8080 with weights 50, 100 and 200.
    fn make_weighted_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("backend-small", 8080, 50),
            UpstreamTarget::new("backend-medium", 8080, 100),
            UpstreamTarget::new("backend-large", 8080, 200),
        ]
    }

    /// Overwrites the tracked connection count for each `(address, count)` pair.
    async fn set_connections(balancer: &WeightedLeastConnBalancer, counts: &[(&str, usize)]) {
        let mut conns = balancer.connections.write().await;
        for &(address, count) in counts {
            conns.insert(address.to_string(), count);
        }
    }

    /// Reads the tracked connection count for `address`, panicking if the
    /// address is not tracked.
    async fn connection_count(balancer: &WeightedLeastConnBalancer, address: &str) -> usize {
        balancer
            .connections
            .read()
            .await
            .get(address)
            .copied()
            .expect("address should be tracked")
    }

    #[tokio::test]
    async fn test_prefers_higher_weight_when_empty() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // All scores are 0, so the default HigherWeight tie-breaker decides.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_weighted_connection_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // 5/50, 10/100 and 20/200 are all 0.1: a three-way tie.
        set_connections(
            &balancer,
            &[
                ("backend-small:8080", 5),
                ("backend-medium:8080", 10),
                ("backend-large:8080", 20),
            ],
        )
        .await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_lower_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Scores: small 10/50 = 0.2, medium 15/100 = 0.15, large 20/200 = 0.1.
        set_connections(
            &balancer,
            &[
                ("backend-small:8080", 10),
                ("backend-medium:8080", 15),
                ("backend-large:8080", 20),
            ],
        )
        .await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_small_when_others_overloaded() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Scores: small 2/50 = 0.04, medium 20/100 = 0.2, large 50/200 = 0.25.
        set_connections(
            &balancer,
            &[
                ("backend-small:8080", 2),
                ("backend-medium:8080", 20),
                ("backend-large:8080", 50),
            ],
        )
        .await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-small:8080");
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let targets = vec![UpstreamTarget::new("backend", 8080, 100)];
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        let selection1 = balancer.select(None).await.unwrap();
        let selection2 = balancer.select(None).await.unwrap();
        assert_eq!(connection_count(&balancer, "backend:8080").await, 2);

        balancer.release(&selection1).await;
        assert_eq!(connection_count(&balancer, "backend:8080").await, 1);

        balancer.release(&selection2).await;
        assert_eq!(connection_count(&balancer, "backend:8080").await, 0);
    }

    #[tokio::test]
    async fn test_fewer_connections_tie_breaker() {
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 100),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::FewerConnections,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        // Equal weights, unequal counts: no tie, the lower count wins outright.
        set_connections(&balancer, &[("backend-a:8080", 5), ("backend-b:8080", 3)]).await;
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-b:8080");

        // Equal weights and equal counts: a true tie. Either target is valid,
        // and exactly one new connection must be recorded.
        set_connections(&balancer, &[("backend-a:8080", 5), ("backend-b:8080", 5)]).await;
        let selection = balancer.select(None).await.unwrap();
        assert!(
            selection.address == "backend-a:8080" || selection.address == "backend-b:8080",
            "unexpected selection: {}",
            selection.address
        );
        let total = connection_count(&balancer, "backend-a:8080").await
            + connection_count(&balancer, "backend-b:8080").await;
        assert_eq!(total, 11);
    }

    #[tokio::test]
    async fn test_fewer_connections_tie_breaker_with_unequal_weights() {
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 50),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::FewerConnections,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        // 10/100 and 5/50 are both 0.1, so the scores tie while the raw counts
        // differ. FewerConnections must pick backend-b; HigherWeight would
        // have picked backend-a.
        set_connections(&balancer, &[("backend-a:8080", 10), ("backend-b:8080", 5)]).await;
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-b:8080");
    }

    #[tokio::test]
    async fn test_respects_health_status() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        balancer.report_health("backend-large:8080", false).await;

        // The unhealthy target must never be returned, whatever the load.
        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert_ne!(selection.address, "backend-large:8080");
        }
    }
}