//! Adaptive load balancing: effective target weights are continuously adjusted
//! from observed error rates and latencies reported by the proxy layer.

use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
use sentinel_common::errors::{SentinelError, SentinelResult};

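/// Tuning knobs controlling how quickly effective weights react to errors and
/// latency, and how far they may drift from the configured weights.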
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// How often effective weights are recomputed.
    pub adjustment_interval: Duration,
    /// Lower bound on the effective weight, as a fraction of the original weight.
    pub min_weight_ratio: f64,
    /// Upper bound on the effective weight, as a multiple of the original weight.
    pub max_weight_ratio: f64,
    /// EWMA error rate above which a target is penalized.
    pub error_threshold: f64,
    /// EWMA latency above which a target is penalized.
    pub latency_threshold: Duration,
    /// Weight given to the newest sample when updating the EWMAs.
    pub ewma_decay: f64,
    /// Multiplier applied each interval to targets that are performing well.
    pub recovery_rate: f64,
    /// Controls how strongly the weight is reduced during a bad interval.
    pub penalty_rate: f64,
    /// Enable the per-target consecutive-failure circuit breaker.
    pub circuit_breaker: bool,
    /// Minimum requests observed in an interval before weights are adjusted.
    pub min_requests: u64,
}

impl Default for AdaptiveConfig {
    fn default() -> Self {
        Self {
            adjustment_interval: Duration::from_secs(10),
            min_weight_ratio: 0.1,
            max_weight_ratio: 2.0,
            error_threshold: 0.05,
            latency_threshold: Duration::from_millis(500),
            ewma_decay: 0.8,
            recovery_rate: 1.1,
            penalty_rate: 0.7,
            circuit_breaker: true,
            min_requests: 100,
        }
    }
}

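/// Per-target runtime statistics. The counters are reset each adjustment
/// interval; the EWMA values and the effective weight persist across intervals.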
#[derive(Debug, Clone)]
struct PerformanceMetrics {
    /// Requests observed in the current adjustment interval.
    total_requests: Arc<AtomicU64>,
    /// Failed requests observed in the current adjustment interval.
    failed_requests: Arc<AtomicU64>,
    /// Sum of latencies (microseconds) for successful requests in the interval.
    total_latency_us: Arc<AtomicU64>,
    /// Successful requests with a recorded latency in the interval.
    success_count: Arc<AtomicU64>,
    /// Requests currently in flight against this target.
    active_connections: Arc<AtomicU64>,
    /// Weight currently used for selection, bounded by the configured ratios.
    effective_weight: Arc<RwLock<f64>>,
    /// Exponentially weighted moving average of the error rate.
    ewma_error_rate: Arc<RwLock<f64>>,
    /// Exponentially weighted moving average of the latency, in microseconds.
    ewma_latency: Arc<RwLock<f64>>,
    /// Timestamp of the last per-target weight adjustment.
    last_adjustment: Arc<RwLock<Instant>>,
    /// Consecutive successes, used to close the circuit breaker.
    consecutive_successes: Arc<AtomicU64>,
    /// Consecutive failures, used to open the circuit breaker.
    consecutive_failures: Arc<AtomicU64>,
    /// Whether the circuit breaker is currently open for this target.
    circuit_open: Arc<RwLock<bool>>,
    /// Time of the most recent failure, if any.
    last_error: Arc<RwLock<Option<Instant>>>,
}

impl PerformanceMetrics {
    fn new(initial_weight: f64) -> Self {
        Self {
            total_requests: Arc::new(AtomicU64::new(0)),
            failed_requests: Arc::new(AtomicU64::new(0)),
            total_latency_us: Arc::new(AtomicU64::new(0)),
            success_count: Arc::new(AtomicU64::new(0)),
            active_connections: Arc::new(AtomicU64::new(0)),
            effective_weight: Arc::new(RwLock::new(initial_weight)),
            ewma_error_rate: Arc::new(RwLock::new(0.0)),
            ewma_latency: Arc::new(RwLock::new(0.0)),
            last_adjustment: Arc::new(RwLock::new(Instant::now())),
            consecutive_successes: Arc::new(AtomicU64::new(0)),
            consecutive_failures: Arc::new(AtomicU64::new(0)),
            circuit_open: Arc::new(RwLock::new(false)),
            last_error: Arc::new(RwLock::new(None)),
        }
    }

    /// Fold the interval's error rate and average latency into the EWMAs;
    /// `decay` is the weight given to the newest sample.
    async fn update_ewma(&self, error_rate: f64, latency_us: f64, decay: f64) {
        let mut ewma_error = self.ewma_error_rate.write().await;
        *ewma_error = decay * error_rate + (1.0 - decay) * (*ewma_error);

        let mut ewma_lat = self.ewma_latency.write().await;
        *ewma_lat = decay * latency_us + (1.0 - decay) * (*ewma_lat);
    }

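    /// Record the outcome of a single request and drive the consecutive
    /// success/failure circuit breaker (opens after 5 failures, closes after
    /// 5 successes).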
    async fn record_result(
        &self,
        success: bool,
        latency: Option<Duration>,
        config: &AdaptiveConfig,
    ) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        if success {
            self.consecutive_successes.fetch_add(1, Ordering::Relaxed);
            self.consecutive_failures.store(0, Ordering::Relaxed);

            // Only successes with a reported latency count toward the latency average.
            if let Some(lat) = latency {
                let lat_us = lat.as_micros() as u64;
                self.total_latency_us.fetch_add(lat_us, Ordering::Relaxed);
                self.success_count.fetch_add(1, Ordering::Relaxed);
            }

            if config.circuit_breaker {
                let successes = self.consecutive_successes.load(Ordering::Relaxed);
                if successes >= 5 && *self.circuit_open.read().await {
                    *self.circuit_open.write().await = false;
                    info!(
                        "Circuit breaker closed after {} consecutive successes",
                        successes
                    );
                }
            }
        } else {
            self.failed_requests.fetch_add(1, Ordering::Relaxed);
            self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
            self.consecutive_successes.store(0, Ordering::Relaxed);
            *self.last_error.write().await = Some(Instant::now());

            if config.circuit_breaker {
                let failures = self.consecutive_failures.load(Ordering::Relaxed);
                if failures >= 5 && !*self.circuit_open.read().await {
                    *self.circuit_open.write().await = true;
                    warn!(
                        "Circuit breaker opened after {} consecutive failures",
                        failures
                    );
                }
            }
        }
    }

    /// Raw error rate over the current adjustment interval.
    fn current_error_rate(&self) -> f64 {
        let total = self.total_requests.load(Ordering::Relaxed);
        if total == 0 {
            return 0.0;
        }
        let failed = self.failed_requests.load(Ordering::Relaxed);
        failed as f64 / total as f64
    }

    /// Mean latency of successful requests in the current interval.
    fn average_latency(&self) -> Duration {
        let count = self.success_count.load(Ordering::Relaxed);
        if count == 0 {
            return Duration::ZERO;
        }
        let total_us = self.total_latency_us.load(Ordering::Relaxed);
        Duration::from_micros(total_us / count)
    }

    /// Clear the per-interval counters; EWMAs and the effective weight persist.
    fn reset_interval_metrics(&self) {
        self.total_requests.store(0, Ordering::Relaxed);
        self.failed_requests.store(0, Ordering::Relaxed);
        self.total_latency_us.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
    }
}

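/// Snapshot of a target's selection score for a single balancing decision.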
#[derive(Debug, Clone)]
struct TargetScore {
    index: usize,
    score: f64,
    weight: f64,
}

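/// Load balancer that continuously re-weights targets from observed error
/// rates and latencies, skipping targets whose circuit breaker is open.
///
/// A minimal usage sketch (the `UpstreamTarget` field values mirror the tests
/// below; error handling elided):
///
/// ```ignore
/// let targets = vec![UpstreamTarget {
///     address: "10.0.0.1".to_string(),
///     port: 8080,
///     weight: 100,
/// }];
/// let balancer = AdaptiveBalancer::new(targets, AdaptiveConfig::default());
/// let selection = balancer.select(None).await?;
/// // ... proxy the request, then report the outcome and release the slot:
/// balancer.report_result(&selection, true, Some(Duration::from_millis(12))).await;
/// balancer.release(&selection).await;
/// ```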
pub struct AdaptiveBalancer {
    /// Tuning parameters.
    config: AdaptiveConfig,
    /// Configured upstream targets, in their original order.
    targets: Vec<UpstreamTarget>,
    /// Weights as configured, used as the baseline for the ratio bounds.
    original_weights: Vec<f64>,
    /// Per-target runtime metrics, indexed in parallel with `targets`.
    metrics: Vec<PerformanceMetrics>,
    /// Externally reported health keyed by "address:port".
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// When weights were last adjusted across all targets.
    last_global_adjustment: Arc<RwLock<Instant>>,
}

impl AdaptiveBalancer {
    pub fn new(targets: Vec<UpstreamTarget>, config: AdaptiveConfig) -> Self {
        let original_weights: Vec<f64> = targets.iter().map(|t| t.weight as f64).collect();
        let metrics = original_weights
            .iter()
            .map(|&w| PerformanceMetrics::new(w))
            .collect();

        Self {
            config,
            targets,
            original_weights,
            metrics,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            last_global_adjustment: Arc::new(RwLock::new(Instant::now())),
        }
    }

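    /// Recompute effective weights if `adjustment_interval` has elapsed.
    /// Targets with fewer than `min_requests` samples in the interval are
    /// left untouched.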
    async fn adjust_weights(&self) {
        let mut last_adjustment = self.last_global_adjustment.write().await;

        if last_adjustment.elapsed() < self.config.adjustment_interval {
            return;
        }

        debug!("Adjusting weights based on performance metrics");

        for (i, metric) in self.metrics.iter().enumerate() {
            let requests = metric.total_requests.load(Ordering::Relaxed);

            // Not enough traffic this interval to make a judgment.
            if requests < self.config.min_requests {
                continue;
            }

            let error_rate = metric.current_error_rate();
            let avg_latency = metric.average_latency();
            let latency_us = avg_latency.as_micros() as f64;

            metric
                .update_ewma(error_rate, latency_us, self.config.ewma_decay)
                .await;

            let ewma_error = *metric.ewma_error_rate.read().await;
            let ewma_latency_us = *metric.ewma_latency.read().await;
            let ewma_latency = Duration::from_micros(ewma_latency_us as u64);

            let mut adjustment = 1.0;

            if ewma_error > self.config.error_threshold {
                let error_factor =
                    1.0 - ((ewma_error - self.config.error_threshold) * 10.0).min(0.9);
                adjustment *= error_factor;
                debug!(
                    "Target {} error rate {:.2}% exceeds threshold, factor: {:.2}",
                    i,
                    ewma_error * 100.0,
                    error_factor
                );
            }

            if ewma_latency > self.config.latency_threshold {
                let latency_ratio =
                    self.config.latency_threshold.as_micros() as f64 / ewma_latency_us;
                adjustment *= latency_ratio.max(0.1);
                debug!(
                    "Target {} latency {:?} exceeds threshold, factor: {:.2}",
                    i, ewma_latency, latency_ratio
                );
            }

            let mut current_weight = *metric.effective_weight.read().await;
            let original = self.original_weights[i];

            // A bad interval pulls the weight down (the multiplier never drops
            // below `penalty_rate`); a good one lets it recover gradually.
            if adjustment < 1.0 {
                current_weight *=
                    self.config.penalty_rate + (1.0 - self.config.penalty_rate) * adjustment;
            } else {
                current_weight *= self.config.recovery_rate;
            }

            // Clamp to the configured bounds relative to the original weight.
            let min_weight = original * self.config.min_weight_ratio;
            let max_weight = original * self.config.max_weight_ratio;
            current_weight = current_weight.max(min_weight).min(max_weight);

            *metric.effective_weight.write().await = current_weight;

            info!(
                "Adjusted weight for target {}: {:.2} (original: {:.2}, error: {:.2}%, latency: {:.0}ms)",
                i,
                current_weight,
                original,
                ewma_error * 100.0,
                ewma_latency.as_millis()
            );

            metric.reset_interval_metrics();
        }

        *last_adjustment = Instant::now();
    }

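    /// Score every healthy target whose circuit is closed. Higher effective
    /// weight raises the score; active connections, EWMA error rate, and EWMA
    /// latency lower it. The result is sorted best-first.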
    async fn calculate_scores(&self) -> Vec<TargetScore> {
        let health = self.health_status.read().await;
        let mut scores = Vec::new();

        for (i, target) in self.targets.iter().enumerate() {
            let target_id = format!("{}:{}", target.address, target.port);
            let is_healthy = health.get(&target_id).copied().unwrap_or(true);

            if !is_healthy || *self.metrics[i].circuit_open.read().await {
                continue;
            }

            let weight = *self.metrics[i].effective_weight.read().await;
            let connections = self.metrics[i].active_connections.load(Ordering::Relaxed) as f64;
            let ewma_error = *self.metrics[i].ewma_error_rate.read().await;
            // EWMA latency is stored in microseconds; convert to milliseconds.
            let ewma_latency = *self.metrics[i].ewma_latency.read().await / 1000.0;

            // Higher weight raises the score; open connections, error rate, and
            // latency all push it down.
            let error_penalty = ewma_error * 100.0;
            let latency_penalty = (ewma_latency / 10.0).max(0.0);
            let score = weight / (1.0 + connections + error_penalty + latency_penalty);

            scores.push(TargetScore {
                index: i,
                score,
                weight,
            });
        }

        scores.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        scores
    }

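    /// Pick a target by roulette-wheel selection proportional to score,
    /// falling back to the best-ranked target when all scores are zero.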
    async fn weighted_select(&self, scores: &[TargetScore]) -> Option<usize> {
        if scores.is_empty() {
            return None;
        }

        let total_score: f64 = scores.iter().map(|s| s.score).sum();
        if total_score <= 0.0 {
            // Degenerate case: every score is zero, so take the best-ranked entry.
            return Some(scores[0].index);
        }

        use rand::prelude::*;
        let mut rng = thread_rng();
        let threshold = rng.gen::<f64>() * total_score;

        let mut cumulative = 0.0;
        for score in scores {
            cumulative += score.score;
            if cumulative >= threshold {
                return Some(score.index);
            }
        }

        scores.last().map(|s| s.index)
    }
}

#[async_trait]
impl LoadBalancer for AdaptiveBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        self.adjust_weights().await;

        let scores = self.calculate_scores().await;

        if scores.is_empty() {
            return Err(SentinelError::NoHealthyUpstream);
        }

        let target_index = self
            .weighted_select(&scores)
            .await
            .ok_or(SentinelError::NoHealthyUpstream)?;

        let target = &self.targets[target_index];
        let metrics = &self.metrics[target_index];

        metrics.active_connections.fetch_add(1, Ordering::Relaxed);

        let effective_weight = *metrics.effective_weight.read().await;
        let ewma_error = *metrics.ewma_error_rate.read().await;
        let ewma_latency = Duration::from_micros(*metrics.ewma_latency.read().await as u64);

        debug!(
            "Adaptive selected target {} with score {:.2}, weight {:.2}",
            target_index,
            scores
                .iter()
                .find(|s| s.index == target_index)
                .map(|s| s.score)
                .unwrap_or(0.0),
            effective_weight
        );

        Ok(TargetSelection {
            address: format!("{}:{}", target.address, target.port),
            weight: target.weight,
            metadata: {
                let mut meta = HashMap::new();
                meta.insert("algorithm".to_string(), "adaptive".to_string());
                meta.insert("target_index".to_string(), target_index.to_string());
                meta.insert(
                    "effective_weight".to_string(),
                    format!("{:.2}", effective_weight),
                );
                meta.insert(
                    "original_weight".to_string(),
                    self.original_weights[target_index].to_string(),
                );
                meta.insert("error_rate".to_string(), format!("{:.4}", ewma_error));
                meta.insert(
                    "latency_ms".to_string(),
                    format!("{:.2}", ewma_latency.as_secs_f64() * 1000.0),
                );
                meta.insert(
                    "connections".to_string(),
                    metrics
                        .active_connections
                        .load(Ordering::Relaxed)
                        .to_string(),
                );
                meta
            },
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        let mut health = self.health_status.write().await;
        let previous = health.insert(address.to_string(), healthy);

        if previous != Some(healthy) {
            info!(
                "Adaptive: Target {} health changed from {:?} to {}",
                address, previous, healthy
            );

            for (i, target) in self.targets.iter().enumerate() {
                let target_id = format!("{}:{}", target.address, target.port);
                if target_id == address {
                    if healthy {
                        // On recovery, start the target fresh at its configured weight.
                        let original = self.original_weights[i];
                        *self.metrics[i].effective_weight.write().await = original;
                        *self.metrics[i].circuit_open.write().await = false;
                        self.metrics[i]
                            .consecutive_failures
                            .store(0, Ordering::Relaxed);
                        info!(
                            "Reset target {} to original weight {:.2} on recovery",
                            i, original
                        );
                    }
                    break;
                }
            }
        }
    }

    async fn healthy_targets(&self) -> Vec<String> {
        let health = self.health_status.read().await;
        let mut targets = Vec::new();

        for (i, target) in self.targets.iter().enumerate() {
            let target_id = format!("{}:{}", target.address, target.port);
            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
            let circuit_open = *self.metrics[i].circuit_open.read().await;

            if is_healthy && !circuit_open {
                targets.push(target_id);
            }
        }

        targets
    }

    async fn release(&self, selection: &TargetSelection) {
        if let Some(index_str) = selection.metadata.get("target_index") {
            if let Ok(index) = index_str.parse::<usize>() {
                self.metrics[index]
                    .active_connections
                    .fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    async fn report_result(
        &self,
        selection: &TargetSelection,
        success: bool,
        latency: Option<Duration>,
    ) {
        if let Some(index_str) = selection.metadata.get("target_index") {
            if let Ok(index) = index_str.parse::<usize>() {
                self.metrics[index]
                    .record_result(success, latency, &self.config)
                    .await;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    #[tokio::test]
    async fn test_weight_degradation() {
        let targets = create_test_targets(3);
        let config = AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            ..Default::default()
        };
        let balancer = AdaptiveBalancer::new(targets, config);

        for _ in 0..10 {
            balancer.metrics[0]
                .record_result(false, None, &balancer.config)
                .await;
        }
        balancer.metrics[0]
            .total_requests
            .store(10, Ordering::Relaxed);

        for _ in 0..10 {
            balancer.metrics[1]
                .record_result(true, Some(Duration::from_millis(10)), &balancer.config)
                .await;
        }
        balancer.metrics[1]
            .total_requests
            .store(10, Ordering::Relaxed);

        tokio::time::sleep(Duration::from_millis(15)).await;

        balancer.adjust_weights().await;

        let weight0 = *balancer.metrics[0].effective_weight.read().await;
        let weight1 = *balancer.metrics[1].effective_weight.read().await;

        assert!(weight0 < 100.0, "Target 0 weight should be degraded");
        assert!(weight1 >= 100.0, "Target 1 weight should not be degraded");
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let targets = create_test_targets(2);
        let config = AdaptiveConfig::default();
        let balancer = AdaptiveBalancer::new(targets, config);

        for _ in 0..5 {
            balancer.metrics[0]
                .record_result(false, None, &balancer.config)
                .await;
        }

        assert!(*balancer.metrics[0].circuit_open.read().await);

        let scores = balancer.calculate_scores().await;
        assert!(!scores.iter().any(|s| s.index == 0));

        for _ in 0..5 {
            balancer.metrics[0]
                .record_result(true, Some(Duration::from_millis(10)), &balancer.config)
                .await;
        }

        assert!(!*balancer.metrics[0].circuit_open.read().await);
    }

    #[tokio::test]
    async fn test_latency_penalty() {
        let targets = create_test_targets(2);
        let config = AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            latency_threshold: Duration::from_millis(100),
            ..Default::default()
        };
        let balancer = AdaptiveBalancer::new(targets, config);

        for _ in 0..10 {
            balancer.metrics[0]
                .record_result(true, Some(Duration::from_millis(500)), &balancer.config)
                .await;
        }
        balancer.metrics[0]
            .total_requests
            .store(10, Ordering::Relaxed);

        for _ in 0..10 {
            balancer.metrics[1]
                .record_result(true, Some(Duration::from_millis(50)), &balancer.config)
                .await;
        }
        balancer.metrics[1]
            .total_requests
            .store(10, Ordering::Relaxed);

        tokio::time::sleep(Duration::from_millis(15)).await;
        balancer.adjust_weights().await;

        let weight0 = *balancer.metrics[0].effective_weight.read().await;
        let weight1 = *balancer.metrics[1].effective_weight.read().await;

        assert!(
            weight0 < weight1,
            "High latency target should have lower weight"
        );
    }
}