sentinel_proxy/upstream/weighted_least_conn.rs

//! Weighted Least Connections load balancer
//!
//! Combines weight-based selection with connection counting. The algorithm
//! selects the backend with the lowest ratio of active connections to weight:
//!
//! Score = active_connections / weight
//!
//! A backend with weight 200 and 10 connections (score: 0.05) is preferred
//! over a backend with weight 100 and 6 connections (score: 0.06).
//!
//! This is useful when backends have different capacities: a higher-weight
//! backend can handle proportionally more concurrent connections.
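//!
//! # Example
//!
//! A minimal usage sketch showing the select/release cycle (marked `ignore`
//! because the enclosing crate paths are assumed, not shown in this file):
//!
//! ```ignore
//! let targets = vec![
//!     UpstreamTarget::new("backend-a", 8080, 100),
//!     UpstreamTarget::new("backend-b", 8080, 200),
//! ];
//! let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());
//!
//! // Pick the target with the lowest connections/weight ratio...
//! let selection = balancer.select(None).await?;
//! // ...proxy the request to selection.address, then release the slot.
//! balancer.release(&selection).await;
//! ```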

use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{trace, warn};

use sentinel_common::errors::{SentinelError, SentinelResult};

use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};

/// Configuration for Weighted Least Connections
#[derive(Debug, Clone)]
pub struct WeightedLeastConnConfig {
    /// Minimum weight to prevent division by zero (default: 1)
    pub min_weight: u32,
    /// Tie-breaker strategy when scores are equal
    pub tie_breaker: TieBreakerStrategy,
}

impl Default for WeightedLeastConnConfig {
    fn default() -> Self {
        Self {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::HigherWeight,
        }
    }
}

/// Strategy for breaking ties when multiple backends have the same score
#[derive(Debug, Clone, Copy, Default)]
pub enum TieBreakerStrategy {
    /// Prefer backend with higher weight (can handle more traffic)
    #[default]
    HigherWeight,
    /// Prefer backend with fewer connections (more headroom)
    FewerConnections,
    /// Round-robin among tied backends
    RoundRobin,
}

/// Weighted Least Connections load balancer
pub struct WeightedLeastConnBalancer {
    /// Target list
    targets: Vec<UpstreamTarget>,
    /// Active connections per target
    connections: Arc<RwLock<HashMap<String, usize>>>,
    /// Health status per target
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Round-robin counter for tie-breaking
    tie_breaker_counter: AtomicUsize,
    /// Configuration
    config: WeightedLeastConnConfig,
}

impl WeightedLeastConnBalancer {
    /// Create a new Weighted Least Connections balancer
    pub fn new(targets: Vec<UpstreamTarget>, config: WeightedLeastConnConfig) -> Self {
        let mut health_status = HashMap::new();
        let mut connections = HashMap::new();

        for target in &targets {
            let addr = target.full_address();
            health_status.insert(addr.clone(), true);
            connections.insert(addr, 0);
        }

        Self {
            targets,
            connections: Arc::new(RwLock::new(connections)),
            health_status: Arc::new(RwLock::new(health_status)),
            tie_breaker_counter: AtomicUsize::new(0),
            config,
        }
    }

    /// Calculate the weighted connection score for a target.
    /// Lower score = better candidate.
    fn calculate_score(&self, connections: usize, weight: u32) -> f64 {
        let effective_weight = weight.max(self.config.min_weight) as f64;
        connections as f64 / effective_weight
    }

    /// Break ties between targets with the same score
    fn break_tie<'a>(
        &self,
        candidates: &[(&'a UpstreamTarget, usize)],
    ) -> Option<&'a UpstreamTarget> {
        if candidates.is_empty() {
            return None;
        }
        if candidates.len() == 1 {
            return Some(candidates[0].0);
        }

        match self.config.tie_breaker {
            TieBreakerStrategy::HigherWeight => {
                candidates.iter().max_by_key(|(t, _)| t.weight).map(|(t, _)| *t)
            }
            TieBreakerStrategy::FewerConnections => {
                candidates.iter().min_by_key(|(_, c)| *c).map(|(t, _)| *t)
            }
            TieBreakerStrategy::RoundRobin => {
                let idx = self.tie_breaker_counter.fetch_add(1, Ordering::Relaxed) % candidates.len();
                Some(candidates[idx].0)
            }
        }
    }
}

#[async_trait]
impl LoadBalancer for WeightedLeastConnBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "weighted_least_conn",
            "Selecting upstream target"
        );

        let health = self.health_status.read().await;
        let conns = self.connections.read().await;

        // Calculate scores for healthy targets
        let scored_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .map(|t| {
                let addr = t.full_address();
                let conn_count = *conns.get(&addr).unwrap_or(&0);
                let score = self.calculate_score(conn_count, t.weight);
                (t, conn_count, score)
            })
            .collect();

        drop(health);

        if scored_targets.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "weighted_least_conn",
                "No healthy upstream targets available"
            );
            return Err(SentinelError::NoHealthyUpstream);
        }

        // Find minimum score
        let min_score = scored_targets
            .iter()
            .map(|(_, _, s)| *s)
            .fold(f64::INFINITY, f64::min);

        // Get all targets with the minimum score (for tie-breaking)
        let candidates: Vec<_> = scored_targets
            .iter()
            .filter(|(_, _, s)| (*s - min_score).abs() < f64::EPSILON)
            .map(|(t, c, _)| (*t, *c))
            .collect();

        let target = self.break_tie(&candidates).ok_or(SentinelError::NoHealthyUpstream)?;

        // Increment the connection count, capturing the new value while the
        // write lock is held so the logged count matches what was stored
        drop(conns);
        let conn_count = {
            let mut conns = self.connections.write().await;
            let count = conns.entry(target.full_address()).or_insert(0);
            *count += 1;
            *count
        };
        let score = self.calculate_score(conn_count, target.weight);

        trace!(
            selected_target = %target.full_address(),
            weight = target.weight,
            connections = conn_count,
            score = score,
            healthy_count = scored_targets.len(),
            algorithm = "weighted_least_conn",
            "Selected target via weighted least connections"
        );

        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    async fn release(&self, selection: &TargetSelection) {
        let mut conns = self.connections.write().await;
        if let Some(count) = conns.get_mut(&selection.address) {
            *count = count.saturating_sub(1);
            trace!(
                target = %selection.address,
                connections = *count,
                algorithm = "weighted_least_conn",
                "Released connection"
            );
        }
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "weighted_least_conn",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_weighted_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("backend-small", 8080, 50),   // Low capacity
            UpstreamTarget::new("backend-medium", 8080, 100), // Medium capacity
            UpstreamTarget::new("backend-large", 8080, 200),  // High capacity
        ]
    }

    #[tokio::test]
    async fn test_prefers_higher_weight_when_empty() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // With no connections, all have score 0, tie-breaker prefers higher weight
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_weighted_connection_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Add connections proportional to weight
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-small:8080".to_string(), 5);   // 5/50 = 0.10
            conns.insert("backend-medium:8080".to_string(), 10); // 10/100 = 0.10
            conns.insert("backend-large:8080".to_string(), 20);  // 20/200 = 0.10
        }

        // All have the same ratio, tie-breaker picks highest weight
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_lower_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // backend-large has the best ratio
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-small:8080".to_string(), 10);  // 10/50 = 0.20
            conns.insert("backend-medium:8080".to_string(), 15); // 15/100 = 0.15
            conns.insert("backend-large:8080".to_string(), 20);  // 20/200 = 0.10 (best)
        }

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_small_when_others_overloaded() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // backend-small has the best ratio despite low weight
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-small:8080".to_string(), 2);   // 2/50 = 0.04 (best)
            conns.insert("backend-medium:8080".to_string(), 20); // 20/100 = 0.20
            conns.insert("backend-large:8080".to_string(), 50);  // 50/200 = 0.25
        }

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-small:8080");
    }
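
    // A sketch of the error path: with every backend reported unhealthy,
    // select() has no candidates and should surface NoHealthyUpstream.
    #[tokio::test]
    async fn test_all_unhealthy_returns_error() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        balancer.report_health("backend-small:8080", false).await;
        balancer.report_health("backend-medium:8080", false).await;
        balancer.report_health("backend-large:8080", false).await;

        let result = balancer.select(None).await;
        assert!(matches!(result, Err(SentinelError::NoHealthyUpstream)));
    }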

    #[tokio::test]
    async fn test_connection_tracking() {
        let targets = vec![UpstreamTarget::new("backend", 8080, 100)];
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Select increments connections
        let selection1 = balancer.select(None).await.unwrap();
        let selection2 = balancer.select(None).await.unwrap();

        {
            let conns = balancer.connections.read().await;
            assert_eq!(*conns.get("backend:8080").unwrap(), 2);
        }

        // Release decrements connections
        balancer.release(&selection1).await;

        {
            let conns = balancer.connections.read().await;
            assert_eq!(*conns.get("backend:8080").unwrap(), 1);
        }

        balancer.release(&selection2).await;

        {
            let conns = balancer.connections.read().await;
            assert_eq!(*conns.get("backend:8080").unwrap(), 0);
        }
    }

    #[tokio::test]
    async fn test_fewer_connections_tie_breaker() {
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 50),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::FewerConnections,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        // Equal scores (10/100 = 5/50 = 0.10) but different raw connection
        // counts, so the FewerConnections tie-breaker must pick backend-b
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-a:8080".to_string(), 10); // 10/100 = 0.10
            conns.insert("backend-b:8080".to_string(), 5);  // 5/50 = 0.10
        }

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-b:8080");
    }
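
    // A sketch exercising the RoundRobin tie-breaker: identical weights and
    // counts tie on every pick, and releasing after each select restores the
    // tie, so the atomic counter should alternate between the two backends
    // (assuming candidates keep target declaration order, as select() does).
    #[tokio::test]
    async fn test_round_robin_tie_breaker() {
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 100),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::RoundRobin,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        let first = balancer.select(None).await.unwrap();
        balancer.release(&first).await;
        let second = balancer.select(None).await.unwrap();
        balancer.release(&second).await;

        assert_ne!(first.address, second.address);
    }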

    #[tokio::test]
    async fn test_respects_health_status() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Mark large backend as unhealthy
        balancer.report_health("backend-large:8080", false).await;

        // Should not select the unhealthy backend
        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert_ne!(selection.address, "backend-large:8080");
        }
    }
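
    // A sketch of the min_weight clamp: calculate_score clamps a zero weight
    // up to min_weight (1 by default), so the score stays finite instead of
    // dividing by zero.
    #[test]
    fn test_min_weight_clamps_zero_weight() {
        let targets = vec![UpstreamTarget::new("backend", 8080, 0)];
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // 7 connections / max(0, 1) = 7.0
        assert_eq!(balancer.calculate_score(7, 0), 7.0);
    }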
}