1use parking_lot::RwLock;
2use rand::Rng;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::time::{Duration, Instant};
7
8use crate::discovery::{DiscoveryEvent, Endpoint, ServiceDiscovery};
9use crate::error::Result;
10use crate::streaming::StreamId;
11
/// Coarse health classification tracked for every backend server.
///
/// Only [`ServerHealth::Unhealthy`] removes a server from selection; see
/// `ServerState::is_available`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ServerHealth {
    /// The server is serving requests normally.
    Healthy,
    /// The server is impaired but is still treated as selectable.
    Degraded,
    /// The server must not be selected.
    Unhealthy,
    /// Nothing is known about the server yet; this is the `Default` value and
    /// the state a newly created `ServerState` starts in.
    #[default]
    Unknown,
}
20
21#[derive(Debug, Clone)]
22pub struct ServerState<S = ()> {
23 pub endpoint: Endpoint,
24 pub health: ServerHealth,
25 pub status: Option<S>,
26 pub last_update: Instant,
27 pub active_requests: usize,
28 pub total_requests: u64,
29 pub total_errors: u64,
30 pub consecutive_failures: u32,
31}
32
33impl<S> ServerState<S> {
34 pub fn new(endpoint: Endpoint) -> Self {
35 Self {
36 endpoint,
37 health: ServerHealth::Unknown,
38 status: None,
39 last_update: Instant::now(),
40 active_requests: 0,
41 total_requests: 0,
42 total_errors: 0,
43 consecutive_failures: 0,
44 }
45 }
46
47 pub fn is_available(&self) -> bool {
48 matches!(
49 self.health,
50 ServerHealth::Healthy | ServerHealth::Degraded | ServerHealth::Unknown
51 )
52 }
53
54 pub fn record_success(&mut self) {
55 self.consecutive_failures = 0;
56 self.total_requests += 1;
57 self.last_update = Instant::now();
58 }
59
60 pub fn record_failure(&mut self) {
61 self.consecutive_failures += 1;
62 self.total_errors += 1;
63 self.total_requests += 1;
64 self.last_update = Instant::now();
65 }
66}
67
68pub trait LoadBalanceStrategy: Send + Sync {
69 type Status: Clone + Send + Sync + 'static;
70
71 fn select(&self, servers: &[ServerState<Self::Status>]) -> Option<usize>;
72
73 fn update_status(&self, _server_idx: usize, _status: Self::Status) {}
74
75 fn on_success(&self, _server_idx: usize) {}
76
77 fn on_failure(&self, _server_idx: usize) {}
78
79 fn name(&self) -> &'static str;
80}
81
/// Strategy that hands out the available servers in rotation.
pub struct RoundRobin {
    /// Ticket counter; each selection takes the next ticket and maps it onto
    /// the currently available servers.
    counter: AtomicUsize,
}

impl RoundRobin {
    /// Creates a strategy whose rotation counter starts at zero.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for RoundRobin {
    /// Equivalent to [`RoundRobin::new`].
    fn default() -> Self {
        Self {
            counter: AtomicUsize::new(0),
        }
    }
}
99
100impl LoadBalanceStrategy for RoundRobin {
101 type Status = ();
102
103 fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
104 let available: Vec<_> = servers
105 .iter()
106 .enumerate()
107 .filter(|(_, s)| s.is_available())
108 .collect();
109
110 if available.is_empty() {
111 return None;
112 }
113
114 let idx = self.counter.fetch_add(1, Ordering::Relaxed);
115 Some(available[idx % available.len()].0)
116 }
117
118 fn name(&self) -> &'static str {
119 "RoundRobin"
120 }
121}
122
/// Strategy that picks uniformly at random among the available servers.
///
/// The type is stateless; `new` and `default` produce the same value.
#[derive(Default)]
pub struct Random;

impl Random {
    /// Creates the (stateless) strategy.
    pub fn new() -> Self {
        Random
    }
}
136
137impl LoadBalanceStrategy for Random {
138 type Status = ();
139
140 fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
141 let available: Vec<_> = servers
142 .iter()
143 .enumerate()
144 .filter(|(_, s)| s.is_available())
145 .map(|(i, _)| i)
146 .collect();
147
148 if available.is_empty() {
149 return None;
150 }
151
152 let idx = rand::thread_rng().gen_range(0..available.len());
153 Some(available[idx])
154 }
155
156 fn name(&self) -> &'static str {
157 "Random"
158 }
159}
160
/// Strategy that prefers the available server with the fewest in-flight
/// requests.
///
/// The type is stateless; `new` and `default` produce the same value.
#[derive(Default)]
pub struct LeastConnections;

impl LeastConnections {
    /// Creates the (stateless) strategy.
    pub fn new() -> Self {
        LeastConnections
    }
}
174
175impl LoadBalanceStrategy for LeastConnections {
176 type Status = ();
177
178 fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
179 servers
180 .iter()
181 .enumerate()
182 .filter(|(_, s)| s.is_available())
183 .min_by_key(|(_, s)| s.active_requests)
184 .map(|(i, _)| i)
185 }
186
187 fn name(&self) -> &'static str {
188 "LeastConnections"
189 }
190}
191
/// Weight bookkeeping for one server in [`WeightedRoundRobin`].
#[derive(Debug, Clone)]
pub struct ServerWeight {
    /// Configured static weight; larger values receive more selections.
    pub weight: u32,
    /// Running balance adjusted on every selection round.
    pub current_weight: i32,
}

impl Default for ServerWeight {
    /// A server without an explicit weight gets weight `1` and a zero
    /// running balance.
    fn default() -> Self {
        ServerWeight {
            weight: 1,
            current_weight: 0,
        }
    }
}
206
207pub struct WeightedRoundRobin {
208 weights: parking_lot::Mutex<Vec<ServerWeight>>,
209}
210
211impl WeightedRoundRobin {
212 pub fn new() -> Self {
213 Self {
214 weights: parking_lot::Mutex::new(Vec::new()),
215 }
216 }
217
218 pub fn with_weights(weights: Vec<u32>) -> Self {
219 let sw: Vec<_> = weights
220 .into_iter()
221 .map(|w| ServerWeight {
222 weight: w,
223 current_weight: 0,
224 })
225 .collect();
226 Self {
227 weights: parking_lot::Mutex::new(sw),
228 }
229 }
230
231 pub fn set_weight(&self, server_idx: usize, weight: u32) {
232 let mut weights = self.weights.lock();
233 while weights.len() <= server_idx {
234 weights.push(ServerWeight::default());
235 }
236 weights[server_idx].weight = weight;
237 }
238}
239
240impl Default for WeightedRoundRobin {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246impl LoadBalanceStrategy for WeightedRoundRobin {
247 type Status = ();
248
249 fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
250 let mut weights = self.weights.lock();
251
252 while weights.len() < servers.len() {
253 weights.push(ServerWeight::default());
254 }
255
256 let mut total_weight = 0i32;
257 let mut best_idx = None;
258 let mut best_weight = i32::MIN;
259
260 for (i, (server, sw)) in servers.iter().zip(weights.iter_mut()).enumerate() {
261 if !server.is_available() {
262 continue;
263 }
264
265 sw.current_weight += sw.weight as i32;
266 total_weight += sw.weight as i32;
267
268 if sw.current_weight > best_weight {
269 best_weight = sw.current_weight;
270 best_idx = Some(i);
271 }
272 }
273
274 if let Some(idx) = best_idx {
275 weights[idx].current_weight -= total_weight;
276 }
277
278 best_idx
279 }
280
281 fn name(&self) -> &'static str {
282 "WeightedRoundRobin"
283 }
284}
285
/// Strategy that selects by a reported per-server score (lower is preferred).
pub struct ScoreBased {
    /// Servers with a fresh score at or above this value are not selected.
    pub threshold: f32,
    /// Age of a server's `last_update` beyond which its data counts as stale.
    pub stale_timeout: Duration,
}

impl ScoreBased {
    /// Creates a strategy with a threshold of `0.95` and a stale timeout of
    /// 30 seconds.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the strategy with `threshold` replaced.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self { threshold, ..self }
    }

    /// Returns the strategy with the stale timeout replaced by `timeout`.
    pub fn with_stale_timeout(self, timeout: Duration) -> Self {
        Self {
            stale_timeout: timeout,
            ..self
        }
    }
}

impl Default for ScoreBased {
    /// Equivalent to [`ScoreBased::new`].
    fn default() -> Self {
        Self {
            threshold: 0.95,
            stale_timeout: Duration::from_secs(30),
        }
    }
}
315
316impl LoadBalanceStrategy for ScoreBased {
317 type Status = f32;
318
319 fn select(&self, servers: &[ServerState<f32>]) -> Option<usize> {
320 servers
321 .iter()
322 .enumerate()
323 .filter(|(_, s)| {
324 if !s.is_available() {
325 return false;
326 }
327 if s.last_update.elapsed() > self.stale_timeout {
328 return true;
329 }
330 s.status.map_or(true, |score| score < self.threshold)
331 })
332 .min_by(|(_, a), (_, b)| {
333 let a_stale = a.last_update.elapsed() > self.stale_timeout;
334 let b_stale = b.last_update.elapsed() > self.stale_timeout;
335
336 match (a_stale, b_stale) {
337 (true, false) => std::cmp::Ordering::Greater,
338 (false, true) => std::cmp::Ordering::Less,
339 _ => {
340 let a_score = a.status.unwrap_or(0.0);
341 let b_score = b.status.unwrap_or(0.0);
342 a_score
343 .partial_cmp(&b_score)
344 .unwrap_or(std::cmp::Ordering::Equal)
345 }
346 }
347 })
348 .map(|(i, _)| i)
349 }
350
351 fn name(&self) -> &'static str {
352 "ScoreBased"
353 }
354}
355
/// Tunables for [`LoadBalancer`].
#[derive(Debug, Clone)]
pub struct LoadBalancerConfig {
    /// Number of consecutive failures at which a server is marked unhealthy.
    pub max_failures: u32,
    /// Interval between health checks.
    // NOTE(review): not read anywhere in this file; presumably consumed by a
    // health checker elsewhere. Confirm.
    pub health_check_interval: Duration,
    /// Whether health checks run automatically.
    // NOTE(review): not read anywhere in this file. Confirm the consumer.
    pub auto_health_check: bool,
    /// Whether failed requests may be retried on another server.
    // NOTE(review): not read anywhere in this file. Confirm the consumer.
    pub failover_enabled: bool,
    /// Upper bound on failover attempts.
    // NOTE(review): not read anywhere in this file. Confirm the consumer.
    pub max_failover_attempts: u32,
}

impl Default for LoadBalancerConfig {
    /// Defaults: 3 failures, 10 s check interval, automatic health checks and
    /// failover enabled, 2 failover attempts.
    fn default() -> Self {
        LoadBalancerConfig {
            max_failures: 3,
            health_check_interval: Duration::from_secs(10),
            auto_health_check: true,
            failover_enabled: true,
            max_failover_attempts: 2,
        }
    }
}
376
377pub struct LoadBalancer<S: LoadBalanceStrategy> {
378 discovery: Arc<dyn ServiceDiscovery>,
379 strategy: S,
380 servers: Arc<RwLock<Vec<ServerState<S::Status>>>>,
381 stream_affinity: RwLock<HashMap<StreamId, usize>>,
382 config: LoadBalancerConfig,
383}
384
385impl<S: LoadBalanceStrategy + 'static> LoadBalancer<S> {
386 pub fn new(discovery: Arc<dyn ServiceDiscovery>, strategy: S) -> Self {
387 Self::with_config(discovery, strategy, LoadBalancerConfig::default())
388 }
389
390 pub fn with_config(
391 discovery: Arc<dyn ServiceDiscovery>,
392 strategy: S,
393 config: LoadBalancerConfig,
394 ) -> Self {
395 Self {
396 discovery,
397 strategy,
398 servers: Arc::new(RwLock::new(Vec::new())),
399 stream_affinity: RwLock::new(HashMap::new()),
400 config,
401 }
402 }
403
404 pub async fn init(&self) -> Result<()> {
405 let endpoints = self.discovery.discover().await?;
406 self.update_endpoints(endpoints);
407 Ok(())
408 }
409
410 pub fn start(&self) -> LoadBalancerHandle {
411 let mut handles = Vec::new();
412
413 if let Some(mut rx) = self.discovery.watch() {
414 let servers = self.servers.clone();
415 let h = tokio::spawn(async move {
416 while let Ok(event) = rx.recv().await {
417 match event {
418 DiscoveryEvent::Updated(endpoints) => {
419 Self::update_endpoints_static(&servers, endpoints);
420 }
421 DiscoveryEvent::Added(endpoint) => {
422 servers.write().push(ServerState::new(endpoint));
423 }
424 DiscoveryEvent::Removed(endpoint) => {
425 servers.write().retain(|s| s.endpoint != endpoint);
426 }
427 }
428 }
429 });
430 handles.push(h);
431 }
432
433 LoadBalancerHandle { handles }
434 }
435
436 fn update_endpoints(&self, endpoints: Vec<Endpoint>) {
437 Self::update_endpoints_static(&self.servers, endpoints);
438 }
439
440 fn update_endpoints_static(
441 servers: &RwLock<Vec<ServerState<S::Status>>>,
442 endpoints: Vec<Endpoint>,
443 ) {
444 let mut servers = servers.write();
445 servers.retain(|s| endpoints.contains(&s.endpoint));
446 for ep in endpoints {
447 if !servers.iter().any(|s| s.endpoint == ep) {
448 servers.push(ServerState::new(ep));
449 }
450 }
451 }
452
453 pub fn select(&self) -> Option<usize> {
454 let servers = self.servers.read();
455 self.strategy.select(&servers)
456 }
457
458 pub fn select_for_stream(&self, stream_id: StreamId) -> Option<usize> {
459 {
460 let affinity = self.stream_affinity.read();
461 if let Some(&idx) = affinity.get(&stream_id) {
462 let servers = self.servers.read();
463 if servers.get(idx).map_or(false, |s| s.is_available()) {
464 return Some(idx);
465 }
466 }
467 }
468
469 let idx = self.select()?;
470 self.stream_affinity.write().insert(stream_id, idx);
471 Some(idx)
472 }
473
474 pub fn release_stream(&self, stream_id: StreamId) {
475 self.stream_affinity.write().remove(&stream_id);
476 }
477
478 pub fn get_endpoint(&self, server_idx: usize) -> Option<Endpoint> {
479 self.servers
480 .read()
481 .get(server_idx)
482 .map(|s| s.endpoint.clone())
483 }
484
485 pub fn report_status(&self, server_idx: usize, status: S::Status) {
486 if let Some(server) = self.servers.write().get_mut(server_idx) {
487 server.status = Some(status.clone());
488 server.last_update = Instant::now();
489 }
490 self.strategy.update_status(server_idx, status);
491 }
492
493 pub fn record_success(&self, server_idx: usize) {
494 if let Some(server) = self.servers.write().get_mut(server_idx) {
495 server.record_success();
496 server.health = ServerHealth::Healthy;
497 }
498 self.strategy.on_success(server_idx);
499 }
500
501 pub fn record_failure(&self, server_idx: usize) {
502 let should_mark_unhealthy = {
503 let mut servers = self.servers.write();
504 if let Some(server) = servers.get_mut(server_idx) {
505 server.record_failure();
506 server.consecutive_failures >= self.config.max_failures
507 } else {
508 false
509 }
510 };
511
512 if should_mark_unhealthy {
513 self.mark_unhealthy(server_idx);
514 }
515
516 self.strategy.on_failure(server_idx);
517 }
518
519 pub fn mark_unhealthy(&self, server_idx: usize) {
520 if let Some(server) = self.servers.write().get_mut(server_idx) {
521 server.health = ServerHealth::Unhealthy;
522 }
523 }
524
525 pub fn mark_healthy(&self, server_idx: usize) {
526 if let Some(server) = self.servers.write().get_mut(server_idx) {
527 server.health = ServerHealth::Healthy;
528 server.consecutive_failures = 0;
529 }
530 }
531
532 pub fn available_count(&self) -> usize {
533 self.servers
534 .read()
535 .iter()
536 .filter(|s| s.is_available())
537 .count()
538 }
539
540 pub fn server_count(&self) -> usize {
541 self.servers.read().len()
542 }
543
544 pub fn strategy_name(&self) -> &'static str {
545 self.strategy.name()
546 }
547
548 pub fn acquire(&self, server_idx: usize) {
549 if let Some(server) = self.servers.write().get_mut(server_idx) {
550 server.active_requests += 1;
551 }
552 }
553
554 pub fn release(&self, server_idx: usize) {
555 if let Some(server) = self.servers.write().get_mut(server_idx) {
556 server.active_requests = server.active_requests.saturating_sub(1);
557 }
558 }
559
560 pub fn config(&self) -> &LoadBalancerConfig {
561 &self.config
562 }
563}
564
565pub struct LoadBalancerHandle {
566 handles: Vec<tokio::task::JoinHandle<()>>,
567}
568
569impl LoadBalancerHandle {
570 pub async fn shutdown(self) {
571 for h in self.handles {
572 h.abort();
573 let _ = h.await;
574 }
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::StaticDiscovery;

    /// Builds `count` fresh servers with addresses `127.0.0.1:800{i}`.
    fn create_test_servers(count: usize) -> Vec<ServerState<()>> {
        let mut servers = Vec::with_capacity(count);
        for i in 0..count {
            let endpoint = Endpoint::tcp_from_str(&format!("127.0.0.1:800{}", i)).unwrap();
            servers.push(ServerState::new(endpoint));
        }
        servers
    }

    /// Calls `select` `rounds` times and collects every successful pick.
    fn collect_picks(
        strategy: &RoundRobin,
        servers: &[ServerState<()>],
        rounds: usize,
    ) -> Vec<usize> {
        let mut picks = Vec::new();
        for _ in 0..rounds {
            if let Some(idx) = strategy.select(servers) {
                picks.push(idx);
            }
        }
        picks
    }

    /// Parses each address into a TCP endpoint.
    fn parse_endpoints(addrs: &[&str]) -> Vec<Endpoint> {
        addrs
            .iter()
            .copied()
            .map(|addr| Endpoint::tcp_from_str(addr).unwrap())
            .collect()
    }

    #[test]
    fn test_round_robin() {
        let servers = create_test_servers(3);
        let strategy = RoundRobin::new();

        let picks = collect_picks(&strategy, &servers, 6);
        assert_eq!(picks, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_round_robin_skip_unhealthy() {
        let mut servers = create_test_servers(3);
        servers[1].health = ServerHealth::Unhealthy;
        let strategy = RoundRobin::new();

        let picks = collect_picks(&strategy, &servers, 4);
        assert_eq!(picks, vec![0, 2, 0, 2]);
    }

    #[test]
    fn test_least_connections() {
        let mut servers = create_test_servers(3);
        for (server, load) in servers.iter_mut().zip([5usize, 2, 3]) {
            server.active_requests = load;
        }

        let picked = LeastConnections::new().select(&servers);
        assert_eq!(picked, Some(1));
    }

    #[test]
    fn test_weighted_round_robin() {
        let servers = create_test_servers(3);
        let strategy = WeightedRoundRobin::with_weights(vec![2, 1, 1]);

        let mut counts = [0usize; 3];
        (0..8)
            .filter_map(|_| strategy.select(&servers))
            .for_each(|idx| counts[idx] += 1);

        // The double-weight server must be picked more often than either
        // single-weight server.
        assert!(counts[0] > counts[1]);
        assert!(counts[0] > counts[2]);
    }

    #[tokio::test]
    async fn test_load_balancer_init() {
        let endpoints = parse_endpoints(&["127.0.0.1:8001", "127.0.0.1:8002"]);
        let discovery = Arc::new(StaticDiscovery::new(endpoints));

        let lb = LoadBalancer::new(discovery, RoundRobin::new());
        lb.init().await.unwrap();

        assert_eq!(lb.server_count(), 2);
        assert_eq!(lb.available_count(), 2);
    }

    #[tokio::test]
    async fn test_stream_affinity() {
        let endpoints = parse_endpoints(&["127.0.0.1:8001", "127.0.0.1:8002"]);
        let discovery = Arc::new(StaticDiscovery::new(endpoints));

        let lb = LoadBalancer::new(discovery, RoundRobin::new());
        lb.init().await.unwrap();

        let stream_id = 42;
        // Repeated lookups for one stream must stay on the same server.
        let first = lb.select_for_stream(stream_id);
        let second = lb.select_for_stream(stream_id);
        let third = lb.select_for_stream(stream_id);

        assert_eq!(first, second);
        assert_eq!(second, third);

        lb.release_stream(stream_id);
    }

    #[tokio::test]
    async fn test_failure_tracking() {
        let endpoints = parse_endpoints(&["127.0.0.1:8001"]);
        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let config = LoadBalancerConfig {
            max_failures: 2,
            ..Default::default()
        };

        let lb = LoadBalancer::with_config(discovery, RoundRobin::new(), config);
        lb.init().await.unwrap();

        // One failure is below the threshold.
        lb.record_failure(0);
        assert_eq!(lb.available_count(), 1);

        // The second failure reaches `max_failures`.
        lb.record_failure(0);
        assert_eq!(lb.available_count(), 0);

        lb.mark_healthy(0);
        assert_eq!(lb.available_count(), 1);
    }
}