1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use tracing::{debug, info, trace, warn};
8
9use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
10use sentinel_common::errors::{SentinelError, SentinelResult};
11
/// Tunable parameters of the adaptive load balancer.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Minimum time that must pass between two weight-adjustment passes.
    pub adjustment_interval: Duration,
    /// Lower bound of a target's effective weight, as a fraction of its original weight.
    pub min_weight_ratio: f64,
    /// Upper bound of a target's effective weight, as a multiple of its original weight.
    pub max_weight_ratio: f64,
    /// Smoothed error rate (fraction of requests) above which a target is penalized.
    pub error_threshold: f64,
    /// Smoothed latency above which a target is penalized.
    pub latency_threshold: Duration,
    /// Share given to the newest sample when the moving averages are updated;
    /// the previous average keeps `1.0 - ewma_decay`.
    pub ewma_decay: f64,
    /// Factor the effective weight is multiplied by in an interval without penalty.
    pub recovery_rate: f64,
    /// Smallest multiplier applied to the weight of a penalized target; milder
    /// penalties are blended between this value and `1.0`.
    pub penalty_rate: f64,
    /// Enables the per-target circuit breaker driven by consecutive failures.
    pub circuit_breaker: bool,
    /// Requests a target must have recorded before its weight is adjusted.
    pub min_requests: u64,
}
36
37impl Default for AdaptiveConfig {
38 fn default() -> Self {
39 Self {
40 adjustment_interval: Duration::from_secs(10),
41 min_weight_ratio: 0.1, max_weight_ratio: 2.0, error_threshold: 0.05, latency_threshold: Duration::from_millis(500),
45 ewma_decay: 0.8, recovery_rate: 1.1, penalty_rate: 0.7, circuit_breaker: true,
49 min_requests: 100,
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
56struct PerformanceMetrics {
57 total_requests: Arc<AtomicU64>,
59 failed_requests: Arc<AtomicU64>,
61 total_latency_us: Arc<AtomicU64>,
63 success_count: Arc<AtomicU64>,
65 active_connections: Arc<AtomicU64>,
67 effective_weight: Arc<RwLock<f64>>,
69 ewma_error_rate: Arc<RwLock<f64>>,
71 ewma_latency: Arc<RwLock<f64>>,
73 last_adjustment: Arc<RwLock<Instant>>,
75 consecutive_successes: Arc<AtomicU64>,
77 consecutive_failures: Arc<AtomicU64>,
79 circuit_open: Arc<RwLock<bool>>,
81 last_error: Arc<RwLock<Option<Instant>>>,
83}
84
85impl PerformanceMetrics {
86 fn new(initial_weight: f64) -> Self {
87 Self {
88 total_requests: Arc::new(AtomicU64::new(0)),
89 failed_requests: Arc::new(AtomicU64::new(0)),
90 total_latency_us: Arc::new(AtomicU64::new(0)),
91 success_count: Arc::new(AtomicU64::new(0)),
92 active_connections: Arc::new(AtomicU64::new(0)),
93 effective_weight: Arc::new(RwLock::new(initial_weight)),
94 ewma_error_rate: Arc::new(RwLock::new(0.0)),
95 ewma_latency: Arc::new(RwLock::new(0.0)),
96 last_adjustment: Arc::new(RwLock::new(Instant::now())),
97 consecutive_successes: Arc::new(AtomicU64::new(0)),
98 consecutive_failures: Arc::new(AtomicU64::new(0)),
99 circuit_open: Arc::new(RwLock::new(false)),
100 last_error: Arc::new(RwLock::new(None)),
101 }
102 }
103
104 async fn update_ewma(&self, error_rate: f64, latency_us: f64, decay: f64) {
106 let mut ewma_error = self.ewma_error_rate.write().await;
107 *ewma_error = decay * error_rate + (1.0 - decay) * (*ewma_error);
108
109 let mut ewma_lat = self.ewma_latency.write().await;
110 *ewma_lat = decay * latency_us + (1.0 - decay) * (*ewma_lat);
111 }
112
113 async fn record_result(
115 &self,
116 success: bool,
117 latency: Option<Duration>,
118 config: &AdaptiveConfig,
119 ) {
120 self.total_requests.fetch_add(1, Ordering::Relaxed);
121
122 if success {
123 self.consecutive_successes.fetch_add(1, Ordering::Relaxed);
124 self.consecutive_failures.store(0, Ordering::Relaxed);
125
126 if let Some(lat) = latency {
127 let lat_us = lat.as_micros() as u64;
128 self.total_latency_us.fetch_add(lat_us, Ordering::Relaxed);
129 self.success_count.fetch_add(1, Ordering::Relaxed);
130 }
131
132 if config.circuit_breaker {
134 let successes = self.consecutive_successes.load(Ordering::Relaxed);
135 if successes >= 5 && *self.circuit_open.read().await {
136 *self.circuit_open.write().await = false;
137 info!(
138 "Circuit breaker closed after {} consecutive successes",
139 successes
140 );
141 }
142 }
143 } else {
144 self.failed_requests.fetch_add(1, Ordering::Relaxed);
145 self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
146 self.consecutive_successes.store(0, Ordering::Relaxed);
147 *self.last_error.write().await = Some(Instant::now());
148
149 if config.circuit_breaker {
151 let failures = self.consecutive_failures.load(Ordering::Relaxed);
152 if failures >= 5 && !*self.circuit_open.read().await {
153 *self.circuit_open.write().await = true;
154 warn!(
155 "Circuit breaker opened after {} consecutive failures",
156 failures
157 );
158 }
159 }
160 }
161 }
162
163 fn current_error_rate(&self) -> f64 {
165 let total = self.total_requests.load(Ordering::Relaxed);
166 if total == 0 {
167 return 0.0;
168 }
169 let failed = self.failed_requests.load(Ordering::Relaxed);
170 failed as f64 / total as f64
171 }
172
173 fn average_latency(&self) -> Duration {
175 let count = self.success_count.load(Ordering::Relaxed);
176 if count == 0 {
177 return Duration::ZERO;
178 }
179 let total_us = self.total_latency_us.load(Ordering::Relaxed);
180 Duration::from_micros(total_us / count)
181 }
182
183 fn reset_interval_metrics(&self) {
185 self.total_requests.store(0, Ordering::Relaxed);
186 self.failed_requests.store(0, Ordering::Relaxed);
187 self.total_latency_us.store(0, Ordering::Relaxed);
188 self.success_count.store(0, Ordering::Relaxed);
189 }
190}
191
/// Score computed for one target that is eligible for selection.
#[derive(Debug, Clone)]
struct TargetScore {
    /// Position of the target in the balancer's target list.
    index: usize,
    /// Selection score; a larger share of the total means a higher chance of
    /// being picked.
    score: f64,
    /// Effective weight the score was derived from.
    weight: f64,
}
199
200pub struct AdaptiveBalancer {
202 config: AdaptiveConfig,
204 targets: Vec<UpstreamTarget>,
206 original_weights: Vec<f64>,
208 metrics: Vec<PerformanceMetrics>,
210 health_status: Arc<RwLock<HashMap<String, bool>>>,
212 last_global_adjustment: Arc<RwLock<Instant>>,
214}
215
216impl AdaptiveBalancer {
217 pub fn new(targets: Vec<UpstreamTarget>, config: AdaptiveConfig) -> Self {
218 trace!(
219 target_count = targets.len(),
220 adjustment_interval_secs = config.adjustment_interval.as_secs(),
221 min_weight_ratio = config.min_weight_ratio,
222 max_weight_ratio = config.max_weight_ratio,
223 error_threshold = config.error_threshold,
224 latency_threshold_ms = config.latency_threshold.as_millis() as u64,
225 ewma_decay = config.ewma_decay,
226 circuit_breaker = config.circuit_breaker,
227 min_requests = config.min_requests,
228 "Creating adaptive balancer"
229 );
230
231 let original_weights: Vec<f64> = targets.iter().map(|t| t.weight as f64).collect();
232 let metrics = original_weights
233 .iter()
234 .map(|&w| PerformanceMetrics::new(w))
235 .collect();
236
237 debug!(
238 target_count = targets.len(),
239 total_weight = original_weights.iter().sum::<f64>(),
240 "Adaptive balancer initialized"
241 );
242
243 Self {
244 config,
245 targets,
246 original_weights,
247 metrics,
248 health_status: Arc::new(RwLock::new(HashMap::new())),
249 last_global_adjustment: Arc::new(RwLock::new(Instant::now())),
250 }
251 }
252
253 async fn adjust_weights(&self) {
255 let mut last_adjustment = self.last_global_adjustment.write().await;
256
257 let elapsed = last_adjustment.elapsed();
258 if elapsed < self.config.adjustment_interval {
259 trace!(
260 elapsed_secs = elapsed.as_secs(),
261 interval_secs = self.config.adjustment_interval.as_secs(),
262 "Skipping weight adjustment (interval not reached)"
263 );
264 return;
265 }
266
267 debug!(
268 elapsed_secs = elapsed.as_secs(),
269 target_count = self.targets.len(),
270 "Adjusting weights based on performance metrics"
271 );
272
273 for (i, metric) in self.metrics.iter().enumerate() {
274 let requests = metric.total_requests.load(Ordering::Relaxed);
275
276 if requests < self.config.min_requests {
278 continue;
279 }
280
281 let error_rate = metric.current_error_rate();
283 let avg_latency = metric.average_latency();
284 let latency_us = avg_latency.as_micros() as f64;
285
286 metric
288 .update_ewma(error_rate, latency_us, self.config.ewma_decay)
289 .await;
290
291 let ewma_error = *metric.ewma_error_rate.read().await;
293 let ewma_latency_us = *metric.ewma_latency.read().await;
294 let ewma_latency = Duration::from_micros(ewma_latency_us as u64);
295
296 let mut adjustment = 1.0;
298
299 if ewma_error > self.config.error_threshold {
301 let error_factor =
302 1.0 - ((ewma_error - self.config.error_threshold) * 10.0).min(0.9);
303 adjustment *= error_factor;
304 debug!(
305 "Target {} error rate {:.2}% exceeds threshold, factor: {:.2}",
306 i,
307 ewma_error * 100.0,
308 error_factor
309 );
310 }
311
312 if ewma_latency > self.config.latency_threshold {
314 let latency_ratio =
315 self.config.latency_threshold.as_micros() as f64 / ewma_latency_us;
316 adjustment *= latency_ratio.max(0.1);
317 debug!(
318 "Target {} latency {:?} exceeds threshold, factor: {:.2}",
319 i, ewma_latency, latency_ratio
320 );
321 }
322
323 let mut current_weight = *metric.effective_weight.read().await;
325 let original = self.original_weights[i];
326
327 if adjustment < 1.0 {
328 current_weight *=
330 self.config.penalty_rate + (1.0 - self.config.penalty_rate) * adjustment;
331 } else {
332 current_weight *= self.config.recovery_rate;
334 }
335
336 let min_weight = original * self.config.min_weight_ratio;
338 let max_weight = original * self.config.max_weight_ratio;
339 current_weight = current_weight.max(min_weight).min(max_weight);
340
341 *metric.effective_weight.write().await = current_weight;
342
343 info!(
344 "Adjusted weight for target {}: {:.2} (original: {:.2}, error: {:.2}%, latency: {:.0}ms)",
345 i,
346 current_weight,
347 original,
348 ewma_error * 100.0,
349 ewma_latency.as_millis()
350 );
351
352 metric.reset_interval_metrics();
354 }
355
356 *last_adjustment = Instant::now();
357 }
358
359 async fn calculate_scores(&self) -> Vec<TargetScore> {
361 trace!(
362 target_count = self.targets.len(),
363 "Calculating scores for all targets"
364 );
365
366 let health = self.health_status.read().await;
367 let mut scores = Vec::new();
368
369 for (i, target) in self.targets.iter().enumerate() {
370 let target_id = format!("{}:{}", target.address, target.port);
371 let is_healthy = health.get(&target_id).copied().unwrap_or(true);
372 let circuit_open = *self.metrics[i].circuit_open.read().await;
373
374 if !is_healthy || circuit_open {
376 trace!(
377 target_index = i,
378 target_id = %target_id,
379 is_healthy = is_healthy,
380 circuit_open = circuit_open,
381 "Skipping target from scoring"
382 );
383 continue;
384 }
385
386 let weight = *self.metrics[i].effective_weight.read().await;
387 let connections = self.metrics[i].active_connections.load(Ordering::Relaxed) as f64;
388 let ewma_error = *self.metrics[i].ewma_error_rate.read().await;
389 let ewma_latency = *self.metrics[i].ewma_latency.read().await / 1000.0; let error_penalty = ewma_error * 100.0; let latency_penalty = (ewma_latency / 10.0).max(0.0); let score = weight / (1.0 + connections + error_penalty + latency_penalty);
395
396 trace!(
397 target_index = i,
398 target_id = %target_id,
399 weight = weight,
400 connections = connections,
401 ewma_error = ewma_error,
402 ewma_latency_ms = ewma_latency,
403 error_penalty = error_penalty,
404 latency_penalty = latency_penalty,
405 score = score,
406 "Calculated target score"
407 );
408
409 scores.push(TargetScore {
410 index: i,
411 score,
412 weight,
413 });
414 }
415
416 scores.sort_by(|a, b| {
418 b.score
419 .partial_cmp(&a.score)
420 .unwrap_or(std::cmp::Ordering::Equal)
421 });
422
423 trace!(
424 scored_count = scores.len(),
425 top_score = scores.first().map(|s| s.score).unwrap_or(0.0),
426 "Scores calculated and sorted"
427 );
428
429 scores
430 }
431
432 async fn weighted_select(&self, scores: &[TargetScore]) -> Option<usize> {
434 if scores.is_empty() {
435 trace!("No scores provided for weighted selection");
436 return None;
437 }
438
439 let total_score: f64 = scores.iter().map(|s| s.score).sum();
441 if total_score <= 0.0 {
442 trace!(
443 fallback_index = scores[0].index,
444 "Total score is zero, using fallback"
445 );
446 return Some(scores[0].index); }
448
449 use rand::prelude::*;
451 let mut rng = thread_rng();
452 let threshold = rng.gen::<f64>() * total_score;
453
454 trace!(
455 total_score = total_score,
456 threshold = threshold,
457 candidate_count = scores.len(),
458 "Performing weighted random selection"
459 );
460
461 let mut cumulative = 0.0;
462 for score in scores {
463 cumulative += score.score;
464 if cumulative >= threshold {
465 trace!(
466 selected_index = score.index,
467 selected_score = score.score,
468 cumulative = cumulative,
469 "Selected target via weighted random"
470 );
471 return Some(score.index);
472 }
473 }
474
475 let fallback = scores.last().map(|s| s.index);
477 trace!(
478 fallback_index = ?fallback,
479 "Using fallback selection (floating point edge case)"
480 );
481 fallback
482 }
483}
484
485#[async_trait]
486impl LoadBalancer for AdaptiveBalancer {
487 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
488 trace!("Adaptive select started");
489
490 self.adjust_weights().await;
492
493 let scores = self.calculate_scores().await;
495
496 if scores.is_empty() {
497 warn!("Adaptive: No healthy targets available");
498 return Err(SentinelError::NoHealthyUpstream);
499 }
500
501 let target_index = self
503 .weighted_select(&scores)
504 .await
505 .ok_or_else(|| {
506 warn!("Adaptive: Failed to select from scores");
507 SentinelError::NoHealthyUpstream
508 })?;
509
510 let target = &self.targets[target_index];
511 let metrics = &self.metrics[target_index];
512
513 let connections = metrics.active_connections.fetch_add(1, Ordering::Relaxed) + 1;
515
516 let effective_weight = *metrics.effective_weight.read().await;
517 let ewma_error = *metrics.ewma_error_rate.read().await;
518 let ewma_latency = Duration::from_micros(*metrics.ewma_latency.read().await as u64);
519
520 let score = scores
521 .iter()
522 .find(|s| s.index == target_index)
523 .map(|s| s.score)
524 .unwrap_or(0.0);
525
526 debug!(
527 target = %format!("{}:{}", target.address, target.port),
528 target_index = target_index,
529 score = score,
530 effective_weight = effective_weight,
531 original_weight = self.original_weights[target_index],
532 error_rate = ewma_error,
533 latency_ms = ewma_latency.as_millis() as u64,
534 connections = connections,
535 "Adaptive selected target"
536 );
537
538 Ok(TargetSelection {
539 address: format!("{}:{}", target.address, target.port),
540 weight: target.weight,
541 metadata: {
542 let mut meta = HashMap::new();
543 meta.insert("algorithm".to_string(), "adaptive".to_string());
544 meta.insert("target_index".to_string(), target_index.to_string());
545 meta.insert(
546 "effective_weight".to_string(),
547 format!("{:.2}", effective_weight),
548 );
549 meta.insert(
550 "original_weight".to_string(),
551 self.original_weights[target_index].to_string(),
552 );
553 meta.insert("error_rate".to_string(), format!("{:.4}", ewma_error));
554 meta.insert(
555 "latency_ms".to_string(),
556 format!("{:.2}", ewma_latency.as_millis()),
557 );
558 meta.insert("connections".to_string(), connections.to_string());
559 meta
560 },
561 })
562 }
563
564 async fn report_health(&self, address: &str, healthy: bool) {
565 trace!(
566 address = %address,
567 healthy = healthy,
568 "Adaptive reporting target health"
569 );
570
571 let mut health = self.health_status.write().await;
572 let previous = health.insert(address.to_string(), healthy);
573
574 if previous != Some(healthy) {
575 info!(
576 address = %address,
577 previous = ?previous,
578 healthy = healthy,
579 "Adaptive target health changed"
580 );
581
582 for (i, target) in self.targets.iter().enumerate() {
584 let target_id = format!("{}:{}", target.address, target.port);
585 if target_id == address {
586 if healthy {
587 let original = self.original_weights[i];
589 *self.metrics[i].effective_weight.write().await = original;
590 *self.metrics[i].circuit_open.write().await = false;
591 self.metrics[i]
592 .consecutive_failures
593 .store(0, Ordering::Relaxed);
594 info!(
595 target_index = i,
596 original_weight = original,
597 "Reset target to original weight on recovery"
598 );
599 }
600 break;
601 }
602 }
603 }
604 }
605
606 async fn healthy_targets(&self) -> Vec<String> {
607 let health = self.health_status.read().await;
608 let mut targets = Vec::new();
609
610 for (i, target) in self.targets.iter().enumerate() {
611 let target_id = format!("{}:{}", target.address, target.port);
612 let is_healthy = health.get(&target_id).copied().unwrap_or(true);
613 let circuit_open = *self.metrics[i].circuit_open.read().await;
614
615 if is_healthy && !circuit_open {
616 targets.push(target_id);
617 }
618 }
619
620 trace!(
621 total = self.targets.len(),
622 healthy = targets.len(),
623 "Adaptive healthy targets"
624 );
625
626 targets
627 }
628
629 async fn release(&self, selection: &TargetSelection) {
630 if let Some(index_str) = selection.metadata.get("target_index") {
631 if let Ok(index) = index_str.parse::<usize>() {
632 let connections = self.metrics[index]
633 .active_connections
634 .fetch_sub(1, Ordering::Relaxed) - 1;
635 trace!(
636 target_index = index,
637 address = %selection.address,
638 connections = connections,
639 "Adaptive released connection"
640 );
641 }
642 }
643 }
644
645 async fn report_result(
646 &self,
647 selection: &TargetSelection,
648 success: bool,
649 latency: Option<Duration>,
650 ) {
651 if let Some(index_str) = selection.metadata.get("target_index") {
652 if let Ok(index) = index_str.parse::<usize>() {
653 trace!(
654 target_index = index,
655 address = %selection.address,
656 success = success,
657 latency_ms = latency.map(|l| l.as_millis() as u64),
658 "Adaptive recording result"
659 );
660 self.metrics[index]
661 .record_result(success, latency, &self.config)
662 .await;
663 }
664 }
665 }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets at 10.0.0.1, 10.0.0.2, ... on port 8080, each
    /// with weight 100.
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        let mut targets = Vec::with_capacity(count);
        for i in 0..count {
            targets.push(UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            });
        }
        targets
    }

    /// Records `count` identical results on the target at `index`.
    async fn record_results(
        balancer: &AdaptiveBalancer,
        index: usize,
        count: usize,
        success: bool,
        latency: Option<Duration>,
    ) {
        for _ in 0..count {
            balancer.metrics[index]
                .record_result(success, latency, &balancer.config)
                .await;
        }
    }

    /// A target that only fails loses weight; a healthy one does not.
    #[tokio::test]
    async fn test_weight_degradation() {
        let config = AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            ..Default::default()
        };
        let balancer = AdaptiveBalancer::new(create_test_targets(3), config);

        // Target 0: ten failures.
        record_results(&balancer, 0, 10, false, None).await;
        balancer.metrics[0]
            .total_requests
            .store(10, Ordering::Relaxed);

        // Target 1: ten fast successes.
        record_results(&balancer, 1, 10, true, Some(Duration::from_millis(10))).await;
        balancer.metrics[1]
            .total_requests
            .store(10, Ordering::Relaxed);

        // Let the adjustment interval elapse before forcing a pass.
        tokio::time::sleep(Duration::from_millis(15)).await;
        balancer.adjust_weights().await;

        let weight0 = *balancer.metrics[0].effective_weight.read().await;
        let weight1 = *balancer.metrics[1].effective_weight.read().await;

        assert!(weight0 < 100.0, "Target 0 weight should be degraded");
        assert!(weight1 >= 100.0, "Target 1 weight should not be degraded");
    }

    /// Five failures open the circuit and hide the target from scoring;
    /// five successes close it again.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let balancer = AdaptiveBalancer::new(create_test_targets(2), AdaptiveConfig::default());

        record_results(&balancer, 0, 5, false, None).await;
        assert!(*balancer.metrics[0].circuit_open.read().await);

        let scores = balancer.calculate_scores().await;
        assert!(!scores.iter().any(|s| s.index == 0));

        record_results(&balancer, 0, 5, true, Some(Duration::from_millis(10))).await;
        assert!(!*balancer.metrics[0].circuit_open.read().await);
    }

    /// A target above the latency threshold ends up with less weight than
    /// one below it.
    #[tokio::test]
    async fn test_latency_penalty() {
        let config = AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            latency_threshold: Duration::from_millis(100),
            ..Default::default()
        };
        let balancer = AdaptiveBalancer::new(create_test_targets(2), config);

        // Target 0: slow successes, well above the 100 ms threshold.
        record_results(&balancer, 0, 10, true, Some(Duration::from_millis(500))).await;
        balancer.metrics[0]
            .total_requests
            .store(10, Ordering::Relaxed);

        // Target 1: fast successes, below the threshold.
        record_results(&balancer, 1, 10, true, Some(Duration::from_millis(50))).await;
        balancer.metrics[1]
            .total_requests
            .store(10, Ordering::Relaxed);

        tokio::time::sleep(Duration::from_millis(15)).await;
        balancer.adjust_weights().await;

        let weight0 = *balancer.metrics[0].effective_weight.read().await;
        let weight1 = *balancer.metrics[1].effective_weight.read().await;

        assert!(
            weight0 < weight1,
            "High latency target should have lower weight"
        );
    }
}