sentinel_proxy/upstream/
adaptive.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use tracing::{debug, info, warn};
8
9use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
10use sentinel_common::errors::{SentinelError, SentinelResult};
11
/// Configuration for adaptive load balancing.
///
/// Ratio and rate fields are plain multipliers (`0.1` means 10%), not values
/// on a 0-100 percentage scale.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Minimum time that must elapse between two weight adjustment passes.
    pub adjustment_interval: Duration,
    /// Lower bound for a target's effective weight, expressed as a multiple
    /// of its original weight (`0.1` = 10% of original).
    pub min_weight_ratio: f64,
    /// Upper bound for a target's effective weight, expressed as a multiple
    /// of its original weight (`2.0` = 200% of original).
    pub max_weight_ratio: f64,
    /// Smoothed error rate (fraction in `0.0..=1.0`) above which a target is
    /// penalized.
    pub error_threshold: f64,
    /// Smoothed latency above which a target is penalized.
    ///
    /// The balancer compares this against an EWMA of the per-interval
    /// *average* latency of successful requests; no percentile is computed.
    pub latency_threshold: Duration,
    /// EWMA decay factor (0.0 to 1.0). This is the weight given to the newest
    /// sample, so higher values react faster to recent data.
    pub ewma_decay: f64,
    /// Multiplier applied to a target's weight on each adjustment pass in
    /// which it is not penalized (`1.1` = grow by 10%).
    pub recovery_rate: f64,
    /// Base multiplier used when a target is penalized. The applied factor is
    /// `penalty_rate + (1 - penalty_rate) * adjustment`, so this is the factor
    /// approached as the computed adjustment goes to zero.
    pub penalty_rate: f64,
    /// Enable the per-target circuit breaker (opens after 5 consecutive
    /// failures, closes after 5 consecutive successes).
    pub circuit_breaker: bool,
    /// Minimum number of requests a target must have accumulated before its
    /// weight is adjusted.
    pub min_requests: u64,
}
36
37impl Default for AdaptiveConfig {
38    fn default() -> Self {
39        Self {
40            adjustment_interval: Duration::from_secs(10),
41            min_weight_ratio: 0.1, // Can go down to 10% of original weight
42            max_weight_ratio: 2.0, // Can go up to 200% of original weight
43            error_threshold: 0.05, // 5% error rate triggers penalty
44            latency_threshold: Duration::from_millis(500),
45            ewma_decay: 0.8,    // Recent data weighted at 80%
46            recovery_rate: 1.1, // 10% recovery per interval
47            penalty_rate: 0.7,  // 30% penalty per interval
48            circuit_breaker: true,
49            min_requests: 100,
50        }
51    }
52}
53
/// Performance metrics for a single target, with EWMA smoothing.
///
/// Every field sits behind an `Arc`, so the derived `Clone` yields a handle
/// that shares the same counters and locks instead of an independent copy.
#[derive(Debug, Clone)]
struct PerformanceMetrics {
    /// Requests recorded since the last `reset_interval_metrics` call.
    total_requests: Arc<AtomicU64>,
    /// Failed requests recorded since the last `reset_interval_metrics` call.
    failed_requests: Arc<AtomicU64>,
    /// Sum of latencies in microseconds, accumulated only for successful
    /// requests that reported a latency; cleared by `reset_interval_metrics`.
    total_latency_us: Arc<AtomicU64>,
    /// Number of successful requests that reported a latency; the divisor
    /// used by `average_latency`. Cleared by `reset_interval_metrics`.
    success_count: Arc<AtomicU64>,
    /// Connections currently handed out for this target (incremented on
    /// selection, decremented on release).
    active_connections: Arc<AtomicU64>,
    /// Current effective weight; starts at the weight passed to `new`.
    effective_weight: Arc<RwLock<f64>>,
    /// EWMA error rate as a fraction; starts at `0.0`.
    ewma_error_rate: Arc<RwLock<f64>>,
    /// EWMA latency in microseconds; starts at `0.0`.
    ewma_latency: Arc<RwLock<f64>>,
    /// Last adjustment time, initialised to the creation instant.
    /// NOTE(review): nothing in the visible part of this file reads or
    /// updates it afterwards.
    last_adjustment: Arc<RwLock<Instant>>,
    /// Consecutive successes; reset to zero by any failure.
    consecutive_successes: Arc<AtomicU64>,
    /// Consecutive failures; reset to zero by any success.
    consecutive_failures: Arc<AtomicU64>,
    /// Circuit breaker state: `true` means open, i.e. the target is skipped
    /// during scoring.
    circuit_open: Arc<RwLock<bool>>,
    /// Time of the most recent failure, `None` until the first one.
    /// NOTE(review): written by `record_result` but not read anywhere in the
    /// visible part of this file.
    last_error: Arc<RwLock<Option<Instant>>>,
}
84
85impl PerformanceMetrics {
86    fn new(initial_weight: f64) -> Self {
87        Self {
88            total_requests: Arc::new(AtomicU64::new(0)),
89            failed_requests: Arc::new(AtomicU64::new(0)),
90            total_latency_us: Arc::new(AtomicU64::new(0)),
91            success_count: Arc::new(AtomicU64::new(0)),
92            active_connections: Arc::new(AtomicU64::new(0)),
93            effective_weight: Arc::new(RwLock::new(initial_weight)),
94            ewma_error_rate: Arc::new(RwLock::new(0.0)),
95            ewma_latency: Arc::new(RwLock::new(0.0)),
96            last_adjustment: Arc::new(RwLock::new(Instant::now())),
97            consecutive_successes: Arc::new(AtomicU64::new(0)),
98            consecutive_failures: Arc::new(AtomicU64::new(0)),
99            circuit_open: Arc::new(RwLock::new(false)),
100            last_error: Arc::new(RwLock::new(None)),
101        }
102    }
103
104    /// Update EWMA values with new sample
105    async fn update_ewma(&self, error_rate: f64, latency_us: f64, decay: f64) {
106        let mut ewma_error = self.ewma_error_rate.write().await;
107        *ewma_error = decay * error_rate + (1.0 - decay) * (*ewma_error);
108
109        let mut ewma_lat = self.ewma_latency.write().await;
110        *ewma_lat = decay * latency_us + (1.0 - decay) * (*ewma_lat);
111    }
112
113    /// Record a request result
114    async fn record_result(
115        &self,
116        success: bool,
117        latency: Option<Duration>,
118        config: &AdaptiveConfig,
119    ) {
120        self.total_requests.fetch_add(1, Ordering::Relaxed);
121
122        if success {
123            self.consecutive_successes.fetch_add(1, Ordering::Relaxed);
124            self.consecutive_failures.store(0, Ordering::Relaxed);
125
126            if let Some(lat) = latency {
127                let lat_us = lat.as_micros() as u64;
128                self.total_latency_us.fetch_add(lat_us, Ordering::Relaxed);
129                self.success_count.fetch_add(1, Ordering::Relaxed);
130            }
131
132            // Check for circuit breaker recovery
133            if config.circuit_breaker {
134                let successes = self.consecutive_successes.load(Ordering::Relaxed);
135                if successes >= 5 && *self.circuit_open.read().await {
136                    *self.circuit_open.write().await = false;
137                    info!(
138                        "Circuit breaker closed after {} consecutive successes",
139                        successes
140                    );
141                }
142            }
143        } else {
144            self.failed_requests.fetch_add(1, Ordering::Relaxed);
145            self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
146            self.consecutive_successes.store(0, Ordering::Relaxed);
147            *self.last_error.write().await = Some(Instant::now());
148
149            // Check for circuit breaker trip
150            if config.circuit_breaker {
151                let failures = self.consecutive_failures.load(Ordering::Relaxed);
152                if failures >= 5 && !*self.circuit_open.read().await {
153                    *self.circuit_open.write().await = true;
154                    warn!(
155                        "Circuit breaker opened after {} consecutive failures",
156                        failures
157                    );
158                }
159            }
160        }
161    }
162
163    /// Calculate current error rate
164    fn current_error_rate(&self) -> f64 {
165        let total = self.total_requests.load(Ordering::Relaxed);
166        if total == 0 {
167            return 0.0;
168        }
169        let failed = self.failed_requests.load(Ordering::Relaxed);
170        failed as f64 / total as f64
171    }
172
173    /// Calculate average latency
174    fn average_latency(&self) -> Duration {
175        let count = self.success_count.load(Ordering::Relaxed);
176        if count == 0 {
177            return Duration::ZERO;
178        }
179        let total_us = self.total_latency_us.load(Ordering::Relaxed);
180        Duration::from_micros(total_us / count)
181    }
182
183    /// Reset interval metrics
184    fn reset_interval_metrics(&self) {
185        self.total_requests.store(0, Ordering::Relaxed);
186        self.failed_requests.store(0, Ordering::Relaxed);
187        self.total_latency_us.store(0, Ordering::Relaxed);
188        self.success_count.store(0, Ordering::Relaxed);
189    }
190}
191
/// Score computed for one selectable target during a selection round.
#[derive(Debug, Clone)]
struct TargetScore {
    /// Position of the target in the balancer's `targets` / `metrics` vectors.
    index: usize,
    /// Selection score; a larger score gives a proportionally larger chance
    /// of being picked by the weighted random draw.
    score: f64,
    /// Effective weight the score was derived from.
    weight: f64,
}
199
/// Adaptive load balancer that adjusts weights based on performance.
///
/// `targets`, `original_weights` and `metrics` are parallel vectors: entry
/// `i` of each describes the same upstream target.
pub struct AdaptiveBalancer {
    /// Tuning parameters for adjustment and circuit breaking.
    config: AdaptiveConfig,
    /// All upstream targets, in the order supplied to `new`.
    targets: Vec<UpstreamTarget>,
    /// Configured weights as `f64`, used to bound effective weights and to
    /// restore a target's weight when it is reported healthy again.
    original_weights: Vec<f64>,
    /// Performance metrics per target.
    metrics: Vec<PerformanceMetrics>,
    /// Reported health keyed by `"address:port"`; a target with no entry is
    /// treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Time of the last weight adjustment pass across all targets.
    last_global_adjustment: Arc<RwLock<Instant>>,
}
215
216impl AdaptiveBalancer {
217    pub fn new(targets: Vec<UpstreamTarget>, config: AdaptiveConfig) -> Self {
218        let original_weights: Vec<f64> = targets.iter().map(|t| t.weight as f64).collect();
219        let metrics = original_weights
220            .iter()
221            .map(|&w| PerformanceMetrics::new(w))
222            .collect();
223
224        Self {
225            config,
226            targets,
227            original_weights,
228            metrics,
229            health_status: Arc::new(RwLock::new(HashMap::new())),
230            last_global_adjustment: Arc::new(RwLock::new(Instant::now())),
231        }
232    }
233
234    /// Adjust weights based on recent performance
235    async fn adjust_weights(&self) {
236        let mut last_adjustment = self.last_global_adjustment.write().await;
237
238        if last_adjustment.elapsed() < self.config.adjustment_interval {
239            return;
240        }
241
242        debug!("Adjusting weights based on performance metrics");
243
244        for (i, metric) in self.metrics.iter().enumerate() {
245            let requests = metric.total_requests.load(Ordering::Relaxed);
246
247            // Skip if insufficient data
248            if requests < self.config.min_requests {
249                continue;
250            }
251
252            // Calculate current metrics
253            let error_rate = metric.current_error_rate();
254            let avg_latency = metric.average_latency();
255            let latency_us = avg_latency.as_micros() as f64;
256
257            // Update EWMA
258            metric
259                .update_ewma(error_rate, latency_us, self.config.ewma_decay)
260                .await;
261
262            // Get smoothed metrics
263            let ewma_error = *metric.ewma_error_rate.read().await;
264            let ewma_latency_us = *metric.ewma_latency.read().await;
265            let ewma_latency = Duration::from_micros(ewma_latency_us as u64);
266
267            // Calculate weight adjustment factor
268            let mut adjustment = 1.0;
269
270            // Penalize high error rates
271            if ewma_error > self.config.error_threshold {
272                let error_factor =
273                    1.0 - ((ewma_error - self.config.error_threshold) * 10.0).min(0.9);
274                adjustment *= error_factor;
275                debug!(
276                    "Target {} error rate {:.2}% exceeds threshold, factor: {:.2}",
277                    i,
278                    ewma_error * 100.0,
279                    error_factor
280                );
281            }
282
283            // Penalize high latencies
284            if ewma_latency > self.config.latency_threshold {
285                let latency_ratio =
286                    self.config.latency_threshold.as_micros() as f64 / ewma_latency_us;
287                adjustment *= latency_ratio.max(0.1);
288                debug!(
289                    "Target {} latency {:?} exceeds threshold, factor: {:.2}",
290                    i, ewma_latency, latency_ratio
291                );
292            }
293
294            // Apply adjustment with damping
295            let mut current_weight = *metric.effective_weight.read().await;
296            let original = self.original_weights[i];
297
298            if adjustment < 1.0 {
299                // Degrade weight
300                current_weight *=
301                    self.config.penalty_rate + (1.0 - self.config.penalty_rate) * adjustment;
302            } else {
303                // Recover weight
304                current_weight *= self.config.recovery_rate;
305            }
306
307            // Apply bounds
308            let min_weight = original * self.config.min_weight_ratio;
309            let max_weight = original * self.config.max_weight_ratio;
310            current_weight = current_weight.max(min_weight).min(max_weight);
311
312            *metric.effective_weight.write().await = current_weight;
313
314            info!(
315                "Adjusted weight for target {}: {:.2} (original: {:.2}, error: {:.2}%, latency: {:.0}ms)",
316                i,
317                current_weight,
318                original,
319                ewma_error * 100.0,
320                ewma_latency.as_millis()
321            );
322
323            // Reset interval metrics
324            metric.reset_interval_metrics();
325        }
326
327        *last_adjustment = Instant::now();
328    }
329
330    /// Calculate scores for all healthy targets
331    async fn calculate_scores(&self) -> Vec<TargetScore> {
332        let health = self.health_status.read().await;
333        let mut scores = Vec::new();
334
335        for (i, target) in self.targets.iter().enumerate() {
336            let target_id = format!("{}:{}", target.address, target.port);
337            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
338
339            // Skip unhealthy or circuit-broken targets
340            if !is_healthy || *self.metrics[i].circuit_open.read().await {
341                continue;
342            }
343
344            let weight = *self.metrics[i].effective_weight.read().await;
345            let connections = self.metrics[i].active_connections.load(Ordering::Relaxed) as f64;
346            let ewma_error = *self.metrics[i].ewma_error_rate.read().await;
347            let ewma_latency = *self.metrics[i].ewma_latency.read().await / 1000.0; // Convert to ms
348
349            // Score formula: weight / (1 + connections + error_penalty + latency_penalty)
350            let error_penalty = ewma_error * 100.0; // Scale error rate
351            let latency_penalty = (ewma_latency / 10.0).max(0.0); // Normalize latency
352            let score = weight / (1.0 + connections + error_penalty + latency_penalty);
353
354            scores.push(TargetScore {
355                index: i,
356                score,
357                weight,
358            });
359        }
360
361        // Sort by score (highest first)
362        scores.sort_by(|a, b| {
363            b.score
364                .partial_cmp(&a.score)
365                .unwrap_or(std::cmp::Ordering::Equal)
366        });
367
368        scores
369    }
370
371    /// Select target using weighted random selection based on scores
372    async fn weighted_select(&self, scores: &[TargetScore]) -> Option<usize> {
373        if scores.is_empty() {
374            return None;
375        }
376
377        // Calculate total score
378        let total_score: f64 = scores.iter().map(|s| s.score).sum();
379        if total_score <= 0.0 {
380            return Some(scores[0].index); // Fallback to first
381        }
382
383        // Weighted random selection
384        use rand::prelude::*;
385        let mut rng = thread_rng();
386        let threshold = rng.gen::<f64>() * total_score;
387
388        let mut cumulative = 0.0;
389        for score in scores {
390            cumulative += score.score;
391            if cumulative >= threshold {
392                return Some(score.index);
393            }
394        }
395
396        // Fallback for floating point edge case - scores is guaranteed non-empty here
397        scores.last().map(|s| s.index)
398    }
399}
400
401#[async_trait]
402impl LoadBalancer for AdaptiveBalancer {
403    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
404        // Periodically adjust weights
405        self.adjust_weights().await;
406
407        // Calculate scores for all targets
408        let scores = self.calculate_scores().await;
409
410        if scores.is_empty() {
411            return Err(SentinelError::NoHealthyUpstream);
412        }
413
414        // Select target based on scores
415        let target_index = self
416            .weighted_select(&scores)
417            .await
418            .ok_or(SentinelError::NoHealthyUpstream)?;
419
420        let target = &self.targets[target_index];
421        let metrics = &self.metrics[target_index];
422
423        // Track connection
424        metrics.active_connections.fetch_add(1, Ordering::Relaxed);
425
426        let effective_weight = *metrics.effective_weight.read().await;
427        let ewma_error = *metrics.ewma_error_rate.read().await;
428        let ewma_latency = Duration::from_micros(*metrics.ewma_latency.read().await as u64);
429
430        debug!(
431            "Adaptive selected target {} with score {:.2}, weight {:.2}",
432            target_index,
433            scores
434                .iter()
435                .find(|s| s.index == target_index)
436                .map(|s| s.score)
437                .unwrap_or(0.0),
438            effective_weight
439        );
440
441        Ok(TargetSelection {
442            address: format!("{}:{}", target.address, target.port),
443            weight: target.weight,
444            metadata: {
445                let mut meta = HashMap::new();
446                meta.insert("algorithm".to_string(), "adaptive".to_string());
447                meta.insert("target_index".to_string(), target_index.to_string());
448                meta.insert(
449                    "effective_weight".to_string(),
450                    format!("{:.2}", effective_weight),
451                );
452                meta.insert(
453                    "original_weight".to_string(),
454                    self.original_weights[target_index].to_string(),
455                );
456                meta.insert("error_rate".to_string(), format!("{:.4}", ewma_error));
457                meta.insert(
458                    "latency_ms".to_string(),
459                    format!("{:.2}", ewma_latency.as_millis()),
460                );
461                meta.insert(
462                    "connections".to_string(),
463                    metrics
464                        .active_connections
465                        .load(Ordering::Relaxed)
466                        .to_string(),
467                );
468                meta
469            },
470        })
471    }
472
473    async fn report_health(&self, address: &str, healthy: bool) {
474        let mut health = self.health_status.write().await;
475        let previous = health.insert(address.to_string(), healthy);
476
477        if previous != Some(healthy) {
478            info!(
479                "Adaptive: Target {} health changed from {:?} to {}",
480                address, previous, healthy
481            );
482
483            // Find target index and reset its weight on health change
484            for (i, target) in self.targets.iter().enumerate() {
485                let target_id = format!("{}:{}", target.address, target.port);
486                if target_id == address {
487                    if healthy {
488                        // Reset to original weight on recovery
489                        let original = self.original_weights[i];
490                        *self.metrics[i].effective_weight.write().await = original;
491                        *self.metrics[i].circuit_open.write().await = false;
492                        self.metrics[i]
493                            .consecutive_failures
494                            .store(0, Ordering::Relaxed);
495                        info!(
496                            "Reset target {} to original weight {:.2} on recovery",
497                            i, original
498                        );
499                    }
500                    break;
501                }
502            }
503        }
504    }
505
506    async fn healthy_targets(&self) -> Vec<String> {
507        let health = self.health_status.read().await;
508        let mut targets = Vec::new();
509
510        for (i, target) in self.targets.iter().enumerate() {
511            let target_id = format!("{}:{}", target.address, target.port);
512            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
513            let circuit_open = *self.metrics[i].circuit_open.read().await;
514
515            if is_healthy && !circuit_open {
516                targets.push(target_id);
517            }
518        }
519
520        targets
521    }
522
523    async fn release(&self, selection: &TargetSelection) {
524        if let Some(index_str) = selection.metadata.get("target_index") {
525            if let Ok(index) = index_str.parse::<usize>() {
526                self.metrics[index]
527                    .active_connections
528                    .fetch_sub(1, Ordering::Relaxed);
529            }
530        }
531    }
532
533    async fn report_result(
534        &self,
535        selection: &TargetSelection,
536        success: bool,
537        latency: Option<Duration>,
538    ) {
539        if let Some(index_str) = selection.metadata.get("target_index") {
540            if let Ok(index) = index_str.parse::<usize>() {
541                self.metrics[index]
542                    .record_result(success, latency, &self.config)
543                    .await;
544            }
545        }
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets at 10.0.0.1, 10.0.0.2, ... on port 8080, each
    /// with weight 100.
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        let mut targets = Vec::with_capacity(count);
        for i in 0..count {
            targets.push(UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            });
        }
        targets
    }

    /// Feeds `count` identical results into the metrics of target `index`.
    async fn record_repeated(
        balancer: &AdaptiveBalancer,
        index: usize,
        count: usize,
        success: bool,
        latency: Option<Duration>,
    ) {
        let metric = &balancer.metrics[index];
        for _ in 0..count {
            metric
                .record_result(success, latency, &balancer.config)
                .await;
        }
    }

    #[tokio::test]
    async fn test_weight_degradation() {
        let config = AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            ..Default::default()
        };
        let balancer = AdaptiveBalancer::new(create_test_targets(3), config);

        // Target 0 fails every request.
        record_repeated(&balancer, 0, 10, false, None).await;
        balancer.metrics[0]
            .total_requests
            .store(10, Ordering::Relaxed);

        // Target 1 answers every request quickly.
        let fast = Some(Duration::from_millis(10));
        record_repeated(&balancer, 1, 10, true, fast).await;
        balancer.metrics[1]
            .total_requests
            .store(10, Ordering::Relaxed);

        // Let the adjustment interval pass, then force an adjustment.
        tokio::time::sleep(Duration::from_millis(15)).await;
        balancer.adjust_weights().await;

        // The failing target must have lost weight; the good one must not.
        let weight0 = *balancer.metrics[0].effective_weight.read().await;
        let weight1 = *balancer.metrics[1].effective_weight.read().await;

        assert!(weight0 < 100.0, "Target 0 weight should be degraded");
        assert!(weight1 >= 100.0, "Target 1 weight should not be degraded");
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let balancer = AdaptiveBalancer::new(create_test_targets(2), AdaptiveConfig::default());

        // Five failures in a row trip the breaker.
        record_repeated(&balancer, 0, 5, false, None).await;
        let open_after_failures = *balancer.metrics[0].circuit_open.read().await;
        assert!(open_after_failures);

        // A circuit-broken target must not be scored.
        let scores = balancer.calculate_scores().await;
        assert!(!scores.iter().any(|s| s.index == 0));

        // Five successes in a row close it again.
        let latency = Some(Duration::from_millis(10));
        record_repeated(&balancer, 0, 5, true, latency).await;
        let open_after_recovery = *balancer.metrics[0].circuit_open.read().await;
        assert!(!open_after_recovery);
    }

    #[tokio::test]
    async fn test_latency_penalty() {
        let config = AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            latency_threshold: Duration::from_millis(100),
            ..Default::default()
        };
        let balancer = AdaptiveBalancer::new(create_test_targets(2), config);

        // Target 0 is slow (well above the 100 ms threshold).
        let slow = Some(Duration::from_millis(500));
        record_repeated(&balancer, 0, 10, true, slow).await;
        balancer.metrics[0]
            .total_requests
            .store(10, Ordering::Relaxed);

        // Target 1 stays below the threshold.
        let normal = Some(Duration::from_millis(50));
        record_repeated(&balancer, 1, 10, true, normal).await;
        balancer.metrics[1]
            .total_requests
            .store(10, Ordering::Relaxed);

        tokio::time::sleep(Duration::from_millis(15)).await;
        balancer.adjust_weights().await;

        let weight0 = *balancer.metrics[0].effective_weight.read().await;
        let weight1 = *balancer.metrics[1].effective_weight.read().await;

        assert!(
            weight0 < weight1,
            "High latency target should have lower weight"
        );
    }
}