// umicp_core/load_balancer.rs

/*!
# Load Balancing

Load balancer for distributing requests across multiple service endpoints.
Supports four selection strategies: round-robin, random, least connections, and weighted.
*/

use crate::discovery::ServiceInfo;
use crate::error::{Result, UmicpError};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

/// Strategy used by [`LoadBalancer`] to pick a backend for each request.
///
/// Only endpoints currently marked healthy are considered by any strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadBalancingStrategy {
    /// Cycle through the healthy endpoints in order, wrapping around.
    RoundRobin,
    /// Pick a healthy endpoint uniformly at random.
    Random,
    /// Pick the healthy endpoint with the fewest active connections.
    LeastConnections,
    /// Pick a healthy endpoint at random with probability proportional to its
    /// weight (falls back to round-robin when every weight is zero).
    Weighted,
}

/// A backend endpoint that the [`LoadBalancer`] can route requests to.
///
/// The connection and request counters live behind `Arc`s, so every clone of an
/// endpoint (including the copies handed out by `LoadBalancer::select` and
/// `LoadBalancer::get_endpoint`) shares the same counters as the original.
#[derive(Debug, Clone)]
pub struct BackendEndpoint {
    /// Identifier used to look the endpoint up (e.g. by `release`).
    pub id: String,
    /// Address of the endpoint.
    pub address: String,
    /// Relative weight for [`LoadBalancingStrategy::Weighted`] (default: 1).
    pub weight: u32,
    /// Connections handed out and not yet released.
    active_connections: Arc<AtomicUsize>,
    /// Total number of times this endpoint has been handed out.
    total_requests: Arc<AtomicU64>,
    /// Whether the endpoint is eligible for selection.
    pub healthy: bool,
}

impl BackendEndpoint {
    /// Creates a healthy endpoint with weight 1 and zeroed counters.
    pub fn new(id: String, address: String) -> Self {
        Self {
            id,
            address,
            weight: 1,
            active_connections: Arc::new(AtomicUsize::new(0)),
            total_requests: Arc::new(AtomicU64::new(0)),
            healthy: true,
        }
    }

    /// Returns this endpoint with its weight replaced by `weight`.
    #[must_use]
    pub fn with_weight(mut self, weight: u32) -> Self {
        self.weight = weight;
        self
    }

    /// Returns the number of active (acquired but not yet released) connections.
    pub fn active_connections(&self) -> usize {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Returns the total number of times this endpoint has been acquired.
    pub fn total_requests(&self) -> u64 {
        self.total_requests.load(Ordering::Relaxed)
    }

    /// Records a new connection: bumps both the active and the total counters.
    pub(crate) fn increment_connections(&self) {
        self.active_connections.fetch_add(1, Ordering::Relaxed);
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }

    /// Records a released connection.
    ///
    /// Saturates at zero. An unmatched release (a double release, or a release
    /// for a connection that was never counted) leaves the counter at 0 instead
    /// of wrapping it to `usize::MAX`. A wrapped counter would make this endpoint
    /// effectively unselectable under least-connections balancing.
    pub(crate) fn decrement_connections(&self) {
        // Returning `None` from the closure (counter already 0) aborts the
        // update; the resulting `Err` is intentionally ignored.
        let _ = self
            .active_connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
}

87impl From<ServiceInfo> for BackendEndpoint {
88    fn from(service: ServiceInfo) -> Self {
89        let weight = service
90            .metadata
91            .get("weight")
92            .and_then(|w| w.parse::<u32>().ok())
93            .unwrap_or(1);
94
95        BackendEndpoint::new(service.service_id, service.address).with_weight(weight)
96    }
97}
98
99/// Load balancer
100pub struct LoadBalancer {
101    /// Balancing strategy
102    strategy: LoadBalancingStrategy,
103    /// Backend endpoints
104    endpoints: Arc<RwLock<Vec<BackendEndpoint>>>,
105    /// Current index for round-robin
106    current_index: Arc<AtomicUsize>,
107}
108
109impl LoadBalancer {
110    /// Create new load balancer with strategy
111    pub fn new(strategy: LoadBalancingStrategy) -> Self {
112        Self {
113            strategy,
114            endpoints: Arc::new(RwLock::new(Vec::new())),
115            current_index: Arc::new(AtomicUsize::new(0)),
116        }
117    }
118
119    /// Add backend endpoint
120    pub fn add_endpoint(&self, endpoint: BackendEndpoint) {
121        self.endpoints.write().push(endpoint);
122    }
123
124    /// Remove endpoint by ID
125    pub fn remove_endpoint(&self, id: &str) -> bool {
126        let mut endpoints = self.endpoints.write();
127        let len_before = endpoints.len();
128        endpoints.retain(|e| e.id != id);
129        endpoints.len() < len_before
130    }
131
132    /// Get endpoint by ID
133    pub fn get_endpoint(&self, id: &str) -> Option<BackendEndpoint> {
134        self.endpoints
135            .read()
136            .iter()
137            .find(|e| e.id == id)
138            .cloned()
139    }
140
141    /// Mark endpoint as healthy/unhealthy
142    pub fn set_endpoint_health(&self, id: &str, healthy: bool) {
143        if let Some(endpoint) = self.endpoints.write().iter_mut().find(|e| e.id == id) {
144            endpoint.healthy = healthy;
145        }
146    }
147
148    /// Get all endpoints
149    pub fn get_endpoints(&self) -> Vec<BackendEndpoint> {
150        self.endpoints.read().clone()
151    }
152
153    /// Get healthy endpoints only
154    pub fn get_healthy_endpoints(&self) -> Vec<BackendEndpoint> {
155        self.endpoints
156            .read()
157            .iter()
158            .filter(|e| e.healthy)
159            .cloned()
160            .collect()
161    }
162
163    /// Select next endpoint based on strategy
164    pub fn select(&self) -> Result<BackendEndpoint> {
165        let healthy_endpoints = self.get_healthy_endpoints();
166
167        if healthy_endpoints.is_empty() {
168            return Err(UmicpError::transport("No healthy endpoints available".to_string()));
169        }
170
171        let endpoint = match self.strategy {
172            LoadBalancingStrategy::RoundRobin => self.select_round_robin(&healthy_endpoints),
173            LoadBalancingStrategy::Random => self.select_random(&healthy_endpoints),
174            LoadBalancingStrategy::LeastConnections => self.select_least_connections(&healthy_endpoints),
175            LoadBalancingStrategy::Weighted => self.select_weighted(&healthy_endpoints),
176        }?;
177
178        // Increment connection count
179        endpoint.increment_connections();
180
181        Ok(endpoint)
182    }
183
184    /// Release endpoint (decrement connection count)
185    pub fn release(&self, endpoint_id: &str) {
186        if let Some(endpoint) = self.endpoints.read().iter().find(|e| e.id == endpoint_id) {
187            endpoint.decrement_connections();
188        }
189    }
190
191    /// Round-robin selection
192    fn select_round_robin(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
193        if endpoints.is_empty() {
194            return Err(UmicpError::transport("No endpoints available".to_string()));
195        }
196
197        let index = self.current_index.fetch_add(1, Ordering::Relaxed) % endpoints.len();
198        Ok(endpoints[index].clone())
199    }
200
201    /// Random selection
202    fn select_random(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
203        if endpoints.is_empty() {
204            return Err(UmicpError::transport("No endpoints available".to_string()));
205        }
206
207        let mut rng = rand::thread_rng();
208        let index = rng.gen_range(0..endpoints.len());
209        Ok(endpoints[index].clone())
210    }
211
212    /// Least connections selection
213    fn select_least_connections(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
214        endpoints
215            .iter()
216            .min_by_key(|e| e.active_connections())
217            .cloned()
218            .ok_or_else(|| UmicpError::transport("No endpoints available".to_string()))
219    }
220
221    /// Weighted selection
222    fn select_weighted(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
223        if endpoints.is_empty() {
224            return Err(UmicpError::transport("No endpoints available".to_string()));
225        }
226
227        // Calculate total weight
228        let total_weight: u32 = endpoints.iter().map(|e| e.weight).sum();
229
230        if total_weight == 0 {
231            // Fall back to round-robin if all weights are 0
232            return self.select_round_robin(endpoints);
233        }
234
235        // Generate random number in range [0, total_weight)
236        let mut rng = rand::thread_rng();
237        let mut random_weight = rng.gen_range(0..total_weight);
238
239        // Select endpoint based on weight
240        for endpoint in endpoints {
241            if random_weight < endpoint.weight {
242                return Ok(endpoint.clone());
243            }
244            random_weight -= endpoint.weight;
245        }
246
247        // Fallback (should not reach here)
248        Ok(endpoints[0].clone())
249    }
250
251    /// Get load balancer statistics
252    pub fn get_stats(&self) -> LoadBalancerStats {
253        let endpoints = self.endpoints.read();
254        let total_endpoints = endpoints.len();
255        let healthy_endpoints = endpoints.iter().filter(|e| e.healthy).count();
256        let total_connections: usize = endpoints.iter().map(|e| e.active_connections()).sum();
257        let total_requests: u64 = endpoints.iter().map(|e| e.total_requests()).sum();
258
259        LoadBalancerStats {
260            strategy: self.strategy,
261            total_endpoints,
262            healthy_endpoints,
263            total_connections,
264            total_requests,
265        }
266    }
267}
268
269/// Load balancer statistics
270#[derive(Debug, Clone)]
271pub struct LoadBalancerStats {
272    pub strategy: LoadBalancingStrategy,
273    pub total_endpoints: usize,
274    pub healthy_endpoints: usize,
275    pub total_connections: usize,
276    pub total_requests: u64,
277}
278
279/// RAII guard for automatic connection release
280pub struct ConnectionGuard<'a> {
281    balancer: &'a LoadBalancer,
282    endpoint_id: String,
283}
284
285impl<'a> ConnectionGuard<'a> {
286    pub fn new(balancer: &'a LoadBalancer, endpoint_id: String) -> Self {
287        Self {
288            balancer,
289            endpoint_id,
290        }
291    }
292
293    pub fn endpoint_id(&self) -> &str {
294        &self.endpoint_id
295    }
296}
297
298impl<'a> Drop for ConnectionGuard<'a> {
299    fn drop(&mut self) {
300        self.balancer.release(&self.endpoint_id);
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an endpoint with default settings.
    fn ep(id: &str, addr: &str) -> BackendEndpoint {
        BackendEndpoint::new(id.to_string(), addr.to_string())
    }

    #[test]
    fn test_backend_endpoint_creation() {
        let endpoint = ep("endpoint-1", "http://localhost:8080");

        assert_eq!(endpoint.id, "endpoint-1");
        assert_eq!(endpoint.address, "http://localhost:8080");
        assert_eq!(endpoint.weight, 1);
        assert_eq!(endpoint.active_connections(), 0);
        assert!(endpoint.healthy);
    }

    #[test]
    fn test_backend_endpoint_with_weight() {
        let endpoint = ep("endpoint-1", "http://localhost:8080").with_weight(5);

        assert_eq!(endpoint.weight, 5);
    }

    #[test]
    fn test_round_robin() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        lb.add_endpoint(ep("ep1", "addr1"));
        lb.add_endpoint(ep("ep2", "addr2"));
        lb.add_endpoint(ep("ep3", "addr3"));

        // Should cycle through endpoints, then wrap around.
        let e1 = lb.select().unwrap();
        let e2 = lb.select().unwrap();
        let e3 = lb.select().unwrap();
        let e4 = lb.select().unwrap();

        assert_eq!(e1.id, "ep1");
        assert_eq!(e2.id, "ep2");
        assert_eq!(e3.id, "ep3");
        assert_eq!(e4.id, "ep1");
    }

    #[test]
    fn test_random() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::Random);

        lb.add_endpoint(ep("ep1", "addr1"));
        lb.add_endpoint(ep("ep2", "addr2"));

        // Every pick must be one of the registered endpoints.
        for _ in 0..10 {
            let endpoint = lb.select().unwrap();
            assert!(endpoint.id == "ep1" || endpoint.id == "ep2");
        }
    }

    #[test]
    fn test_least_connections() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);

        let ep1 = ep("ep1", "addr1");
        let ep2 = ep("ep2", "addr2");

        // Simulate ep1 having more connections.
        ep1.increment_connections();
        ep1.increment_connections();

        lb.add_endpoint(ep1);
        lb.add_endpoint(ep2);

        // Should select ep2 (fewer connections).
        let selected = lb.select().unwrap();
        assert_eq!(selected.id, "ep2");
    }

    #[test]
    fn test_weighted() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::Weighted);

        lb.add_endpoint(ep("ep1", "addr1").with_weight(1));
        lb.add_endpoint(ep("ep2", "addr2").with_weight(9));

        // With weights 1:9, ep2 should be selected ~90% of the time.
        let mut ep1_count = 0;
        let mut ep2_count = 0;

        for _ in 0..100 {
            if lb.select().unwrap().id == "ep1" {
                ep1_count += 1;
            } else {
                ep2_count += 1;
            }
        }

        // Rough check: ep2 should win more often.
        assert!(ep2_count > ep1_count);
    }

    #[test]
    fn test_weighted_all_zero_falls_back_to_round_robin() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::Weighted);

        lb.add_endpoint(ep("ep1", "addr1").with_weight(0));
        lb.add_endpoint(ep("ep2", "addr2").with_weight(0));

        let ids: Vec<String> = (0..4).map(|_| lb.select().unwrap().id).collect();
        assert_eq!(ids, ["ep1", "ep2", "ep1", "ep2"]);
    }

    #[test]
    fn test_healthy_endpoints_only() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        lb.add_endpoint(ep("ep1", "addr1"));
        lb.add_endpoint(ep("ep2", "addr2"));

        // Mark ep1 as unhealthy; only ep2 may be selected.
        lb.set_endpoint_health("ep1", false);

        for _ in 0..5 {
            let endpoint = lb.select().unwrap();
            assert_eq!(endpoint.id, "ep2");
        }
    }

    #[test]
    fn test_no_healthy_endpoints() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        lb.add_endpoint(ep("ep1", "addr1"));
        lb.set_endpoint_health("ep1", false);

        assert!(lb.select().is_err());
    }

    #[test]
    fn test_remove_endpoint() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        lb.add_endpoint(ep("ep1", "addr1"));

        assert!(lb.remove_endpoint("ep1"));
        assert!(!lb.remove_endpoint("ep1"));
        assert!(lb.get_endpoint("ep1").is_none());
        assert!(lb.select().is_err());
    }

    #[test]
    fn test_connection_release() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);
        lb.add_endpoint(ep("ep1", "addr1"));

        let endpoint = lb.select().unwrap();
        assert_eq!(endpoint.active_connections(), 1);

        lb.release(&endpoint.id);

        let endpoint_after = lb.get_endpoint("ep1").unwrap();
        assert_eq!(endpoint_after.active_connections(), 0);
    }

    #[test]
    fn test_connection_guard_releases_on_drop() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);
        lb.add_endpoint(ep("ep1", "addr1"));

        let endpoint = lb.select().unwrap();
        {
            let guard = ConnectionGuard::new(&lb, endpoint.id.clone());
            assert_eq!(guard.endpoint_id(), "ep1");
            assert_eq!(lb.get_endpoint("ep1").unwrap().active_connections(), 1);
        }
        assert_eq!(lb.get_endpoint("ep1").unwrap().active_connections(), 0);
    }

    #[test]
    fn test_stats() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        lb.add_endpoint(ep("ep1", "addr1"));
        lb.add_endpoint(ep("ep2", "addr2"));
        lb.set_endpoint_health("ep2", false);

        let _ = lb.select();

        let stats = lb.get_stats();
        assert_eq!(stats.total_endpoints, 2);
        assert_eq!(stats.healthy_endpoints, 1);
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.total_requests, 1);
    }
}
469