1use crate::discovery::ServiceInfo;
9use crate::error::{Result, UmicpError};
10use parking_lot::RwLock;
11use rand::Rng;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
15
/// Policy a `LoadBalancer` uses to choose among its healthy endpoints.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadBalancingStrategy {
    /// Cycle through the healthy endpoints in order.
    RoundRobin,
    /// Pick a healthy endpoint uniformly at random.
    Random,
    /// Pick the healthy endpoint with the fewest active connections (earliest wins ties).
    LeastConnections,
    /// Pick at random with probability proportional to each endpoint's `weight`.
    Weighted,
}
28
/// A single backend that a `LoadBalancer` can route traffic to.
///
/// The connection and request counters live behind `Arc`s, so every clone of an endpoint
/// shares them. Counting done through a clone is therefore visible on the endpoint it was
/// cloned from. The remaining fields (`id`, `address`, `weight`, `healthy`) are plain copies.
#[derive(Debug, Clone)]
pub struct BackendEndpoint {
    /// Identifier used for id-based lookups on the balancer.
    pub id: String,
    /// Backend address; its format is opaque to this type.
    pub address: String,
    /// Relative weight used by weighted selection (1 unless overridden).
    pub weight: u32,
    /// Connections currently counted as in flight (shared between clones).
    active_connections: Arc<AtomicUsize>,
    /// Connections ever counted; only ever incremented (shared between clones).
    total_requests: Arc<AtomicU64>,
    /// Whether the endpoint is eligible for selection.
    pub healthy: bool,
}

impl BackendEndpoint {
    /// Creates a healthy endpoint with weight 1 and both counters at zero.
    pub fn new(id: String, address: String) -> Self {
        Self {
            id,
            address,
            weight: 1,
            active_connections: Arc::new(AtomicUsize::new(0)),
            total_requests: Arc::new(AtomicU64::new(0)),
            healthy: true,
        }
    }

    /// Returns this endpoint with its weight replaced by `weight` (builder style).
    pub fn with_weight(mut self, weight: u32) -> Self {
        self.weight = weight;
        self
    }

    /// Number of connections currently counted as in flight.
    pub fn active_connections(&self) -> usize {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Total number of connections counted since the endpoint was created.
    pub fn total_requests(&self) -> u64 {
        self.total_requests.load(Ordering::Relaxed)
    }

    /// Records a newly opened connection by bumping both the in-flight and total counters.
    pub(crate) fn increment_connections(&self) {
        self.active_connections.fetch_add(1, Ordering::Relaxed);
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }

    /// Records a closed connection by decrementing the in-flight counter.
    ///
    /// The counter saturates at zero. An unmatched release (for example releasing twice) is
    /// ignored rather than wrapping the counter around to `usize::MAX`. A wrapped counter would
    /// make least-connections selection skip this endpoint forever.
    pub(crate) fn decrement_connections(&self) {
        // Returning `None` from the closure (counter already 0) leaves the value untouched;
        // the resulting `Err` is expected and deliberately ignored.
        let _ = self
            .active_connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(1)
            });
    }
}
86
87impl From<ServiceInfo> for BackendEndpoint {
88 fn from(service: ServiceInfo) -> Self {
89 let weight = service
90 .metadata
91 .get("weight")
92 .and_then(|w| w.parse::<u32>().ok())
93 .unwrap_or(1);
94
95 BackendEndpoint::new(service.service_id, service.address).with_weight(weight)
96 }
97}
98
99pub struct LoadBalancer {
101 strategy: LoadBalancingStrategy,
103 endpoints: Arc<RwLock<Vec<BackendEndpoint>>>,
105 current_index: Arc<AtomicUsize>,
107}
108
109impl LoadBalancer {
110 pub fn new(strategy: LoadBalancingStrategy) -> Self {
112 Self {
113 strategy,
114 endpoints: Arc::new(RwLock::new(Vec::new())),
115 current_index: Arc::new(AtomicUsize::new(0)),
116 }
117 }
118
119 pub fn add_endpoint(&self, endpoint: BackendEndpoint) {
121 self.endpoints.write().push(endpoint);
122 }
123
124 pub fn remove_endpoint(&self, id: &str) -> bool {
126 let mut endpoints = self.endpoints.write();
127 let len_before = endpoints.len();
128 endpoints.retain(|e| e.id != id);
129 endpoints.len() < len_before
130 }
131
132 pub fn get_endpoint(&self, id: &str) -> Option<BackendEndpoint> {
134 self.endpoints
135 .read()
136 .iter()
137 .find(|e| e.id == id)
138 .cloned()
139 }
140
141 pub fn set_endpoint_health(&self, id: &str, healthy: bool) {
143 if let Some(endpoint) = self.endpoints.write().iter_mut().find(|e| e.id == id) {
144 endpoint.healthy = healthy;
145 }
146 }
147
148 pub fn get_endpoints(&self) -> Vec<BackendEndpoint> {
150 self.endpoints.read().clone()
151 }
152
153 pub fn get_healthy_endpoints(&self) -> Vec<BackendEndpoint> {
155 self.endpoints
156 .read()
157 .iter()
158 .filter(|e| e.healthy)
159 .cloned()
160 .collect()
161 }
162
163 pub fn select(&self) -> Result<BackendEndpoint> {
165 let healthy_endpoints = self.get_healthy_endpoints();
166
167 if healthy_endpoints.is_empty() {
168 return Err(UmicpError::transport("No healthy endpoints available".to_string()));
169 }
170
171 let endpoint = match self.strategy {
172 LoadBalancingStrategy::RoundRobin => self.select_round_robin(&healthy_endpoints),
173 LoadBalancingStrategy::Random => self.select_random(&healthy_endpoints),
174 LoadBalancingStrategy::LeastConnections => self.select_least_connections(&healthy_endpoints),
175 LoadBalancingStrategy::Weighted => self.select_weighted(&healthy_endpoints),
176 }?;
177
178 endpoint.increment_connections();
180
181 Ok(endpoint)
182 }
183
184 pub fn release(&self, endpoint_id: &str) {
186 if let Some(endpoint) = self.endpoints.read().iter().find(|e| e.id == endpoint_id) {
187 endpoint.decrement_connections();
188 }
189 }
190
191 fn select_round_robin(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
193 if endpoints.is_empty() {
194 return Err(UmicpError::transport("No endpoints available".to_string()));
195 }
196
197 let index = self.current_index.fetch_add(1, Ordering::Relaxed) % endpoints.len();
198 Ok(endpoints[index].clone())
199 }
200
201 fn select_random(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
203 if endpoints.is_empty() {
204 return Err(UmicpError::transport("No endpoints available".to_string()));
205 }
206
207 let mut rng = rand::thread_rng();
208 let index = rng.gen_range(0..endpoints.len());
209 Ok(endpoints[index].clone())
210 }
211
212 fn select_least_connections(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
214 endpoints
215 .iter()
216 .min_by_key(|e| e.active_connections())
217 .cloned()
218 .ok_or_else(|| UmicpError::transport("No endpoints available".to_string()))
219 }
220
221 fn select_weighted(&self, endpoints: &[BackendEndpoint]) -> Result<BackendEndpoint> {
223 if endpoints.is_empty() {
224 return Err(UmicpError::transport("No endpoints available".to_string()));
225 }
226
227 let total_weight: u32 = endpoints.iter().map(|e| e.weight).sum();
229
230 if total_weight == 0 {
231 return self.select_round_robin(endpoints);
233 }
234
235 let mut rng = rand::thread_rng();
237 let mut random_weight = rng.gen_range(0..total_weight);
238
239 for endpoint in endpoints {
241 if random_weight < endpoint.weight {
242 return Ok(endpoint.clone());
243 }
244 random_weight -= endpoint.weight;
245 }
246
247 Ok(endpoints[0].clone())
249 }
250
251 pub fn get_stats(&self) -> LoadBalancerStats {
253 let endpoints = self.endpoints.read();
254 let total_endpoints = endpoints.len();
255 let healthy_endpoints = endpoints.iter().filter(|e| e.healthy).count();
256 let total_connections: usize = endpoints.iter().map(|e| e.active_connections()).sum();
257 let total_requests: u64 = endpoints.iter().map(|e| e.total_requests()).sum();
258
259 LoadBalancerStats {
260 strategy: self.strategy,
261 total_endpoints,
262 healthy_endpoints,
263 total_connections,
264 total_requests,
265 }
266 }
267}
268
269#[derive(Debug, Clone)]
271pub struct LoadBalancerStats {
272 pub strategy: LoadBalancingStrategy,
273 pub total_endpoints: usize,
274 pub healthy_endpoints: usize,
275 pub total_connections: usize,
276 pub total_requests: u64,
277}
278
279pub struct ConnectionGuard<'a> {
281 balancer: &'a LoadBalancer,
282 endpoint_id: String,
283}
284
285impl<'a> ConnectionGuard<'a> {
286 pub fn new(balancer: &'a LoadBalancer, endpoint_id: String) -> Self {
287 Self {
288 balancer,
289 endpoint_id,
290 }
291 }
292
293 pub fn endpoint_id(&self) -> &str {
294 &self.endpoint_id
295 }
296}
297
298impl<'a> Drop for ConnectionGuard<'a> {
299 fn drop(&mut self) {
300 self.balancer.release(&self.endpoint_id);
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a healthy, weight-1 endpoint.
    fn endpoint(id: &str, address: &str) -> BackendEndpoint {
        BackendEndpoint::new(id.to_string(), address.to_string())
    }

    /// Builds a balancer pre-populated with the given `(id, address)` pairs, in order.
    fn balancer_with(strategy: LoadBalancingStrategy, pairs: &[(&str, &str)]) -> LoadBalancer {
        let lb = LoadBalancer::new(strategy);
        for &(id, address) in pairs {
            lb.add_endpoint(endpoint(id, address));
        }
        lb
    }

    #[test]
    fn test_backend_endpoint_creation() {
        let ep = endpoint("endpoint-1", "http://localhost:8080");

        assert_eq!(ep.id, "endpoint-1");
        assert_eq!(ep.address, "http://localhost:8080");
        assert_eq!(ep.weight, 1);
        assert_eq!(ep.active_connections(), 0);
        assert!(ep.healthy);
    }

    #[test]
    fn test_backend_endpoint_with_weight() {
        let ep = endpoint("endpoint-1", "http://localhost:8080").with_weight(5);

        assert_eq!(ep.weight, 5);
    }

    #[test]
    fn test_round_robin() {
        let lb = balancer_with(
            LoadBalancingStrategy::RoundRobin,
            &[("ep1", "addr1"), ("ep2", "addr2"), ("ep3", "addr3")],
        );

        // Four picks over three endpoints wrap back around to the first one.
        let picked: Vec<String> = (0..4).map(|_| lb.select().unwrap().id).collect();
        assert_eq!(picked, ["ep1", "ep2", "ep3", "ep1"]);
    }

    #[test]
    fn test_random() {
        let lb = balancer_with(
            LoadBalancingStrategy::Random,
            &[("ep1", "addr1"), ("ep2", "addr2")],
        );

        for _ in 0..10 {
            let id = lb.select().unwrap().id;
            assert!(id == "ep1" || id == "ep2");
        }
    }

    #[test]
    fn test_least_connections() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);

        // ep1 already carries two connections, so the idle ep2 must win.
        let busy = endpoint("ep1", "addr1");
        busy.increment_connections();
        busy.increment_connections();
        lb.add_endpoint(busy);
        lb.add_endpoint(endpoint("ep2", "addr2"));

        assert_eq!(lb.select().unwrap().id, "ep2");
    }

    #[test]
    fn test_weighted() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::Weighted);
        lb.add_endpoint(endpoint("ep1", "addr1").with_weight(1));
        lb.add_endpoint(endpoint("ep2", "addr2").with_weight(9));

        // With a 1:9 weight ratio, ep2 should dominate 100 draws.
        let ep1_hits = (0..100)
            .filter(|_| lb.select().unwrap().id == "ep1")
            .count();
        let ep2_hits = 100 - ep1_hits;

        assert!(ep2_hits > ep1_hits);
    }

    #[test]
    fn test_healthy_endpoints_only() {
        let lb = balancer_with(
            LoadBalancingStrategy::RoundRobin,
            &[("ep1", "addr1"), ("ep2", "addr2")],
        );
        lb.set_endpoint_health("ep1", false);

        for _ in 0..5 {
            assert_eq!(lb.select().unwrap().id, "ep2");
        }
    }

    #[test]
    fn test_no_healthy_endpoints() {
        let lb = balancer_with(LoadBalancingStrategy::RoundRobin, &[("ep1", "addr1")]);
        lb.set_endpoint_health("ep1", false);

        assert!(lb.select().is_err());
    }

    #[test]
    fn test_connection_release() {
        let lb = balancer_with(LoadBalancingStrategy::LeastConnections, &[("ep1", "addr1")]);

        let selected = lb.select().unwrap();
        assert_eq!(selected.active_connections(), 1);

        lb.release(&selected.id);

        let after = lb.get_endpoint("ep1").unwrap();
        assert_eq!(after.active_connections(), 0);
    }

    #[test]
    fn test_stats() {
        let lb = balancer_with(
            LoadBalancingStrategy::RoundRobin,
            &[("ep1", "addr1"), ("ep2", "addr2")],
        );
        lb.set_endpoint_health("ep2", false);

        let _ = lb.select();
        let stats = lb.get_stats();

        assert_eq!(stats.total_endpoints, 2);
        assert_eq!(stats.healthy_endpoints, 1);
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.total_requests, 1);
    }
}
469