sentinel_proxy/upstream/
peak_ewma.rs

//! Peak EWMA load balancer
//!
//! Implements Twitter Finagle's Peak EWMA (Exponentially Weighted Moving Average)
//! algorithm. The latency of each backend is tracked with an exponentially
//! weighted moving average, and the backend with the lowest predicted completion
//! time is selected.
//!
//! The "peak" aspect means we use the maximum of:
//! - the current EWMA latency, and
//! - the most recently observed latency (to react quickly to latency spikes).
//!
//! That latency is then scaled by the number of in-flight requests on the
//! backend (see `PeakEwmaConfig::load_penalty`). Unlike Finagle's
//! power-of-two-choices variant, every healthy backend is scored on each
//! selection.
//!
//! Reference: https://twitter.github.io/finagle/guide/Clients.html#power-of-two-choices-p2c-least-loaded
13
14use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::sync::RwLock;
20use tracing::{debug, trace, warn};
21
22use sentinel_common::errors::{SentinelError, SentinelResult};
23
24use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
25
/// Tuning knobs for the Peak EWMA load balancer.
///
/// `PeakEwmaConfig::default()` yields a 10 s decay time, a 1 ms initial latency
/// estimate and a load penalty of 1.5.
#[derive(Debug, Clone)]
pub struct PeakEwmaConfig {
    /// Time constant of the exponential decay applied to latency samples
    /// (default: 10 s). Smaller values weight recent observations more heavily.
    pub decay_time: Duration,
    /// Latency assumed for a backend before any sample has been recorded
    /// (default: 1 ms).
    pub initial_latency: Duration,
    /// Multiplier applied per in-flight request when scoring a backend
    /// (default: 1.5). Larger values steer traffic toward less busy backends.
    pub load_penalty: f64,
}

impl Default for PeakEwmaConfig {
    /// 10 s decay, 1 ms initial latency estimate, load penalty of 1.5.
    fn default() -> Self {
        let decay_time = Duration::from_secs(10);
        let initial_latency = Duration::from_millis(1);
        let load_penalty = 1.5;
        Self {
            decay_time,
            initial_latency,
            load_penalty,
        }
    }
}
48
/// Per-target statistics for EWMA tracking.
///
/// Every field is an independent atomic accessed with `Ordering::Relaxed`.
/// `update` is a load-then-store sequence rather than a single atomic step, so
/// two concurrent updates can race and one observation may be lost; that is
/// tolerated because these values only feed a load-balancing heuristic.
struct TargetStats {
    /// EWMA latency in nanoseconds.
    ewma_ns: AtomicU64,
    /// Most recently observed latency in nanoseconds.
    last_latency_ns: AtomicU64,
    /// Time of the last `update`, in nanoseconds since `epoch`
    /// (0 until the first update, i.e. the creation instant).
    last_update_ns: AtomicU64,
    /// Number of in-flight requests currently counted against this target.
    active_connections: AtomicU64,
    /// Creation instant; reference point for `last_update_ns`.
    epoch: Instant,
}

impl TargetStats {
    /// Creates stats whose EWMA and last-observed latency are both seeded with
    /// `initial_latency`, with no active connections.
    fn new(initial_latency: Duration) -> Self {
        let initial_ns = Self::saturating_nanos(initial_latency);
        Self {
            ewma_ns: AtomicU64::new(initial_ns),
            last_latency_ns: AtomicU64::new(initial_ns),
            last_update_ns: AtomicU64::new(0),
            active_connections: AtomicU64::new(0),
            epoch: Instant::now(),
        }
    }

    /// Converts `duration` to whole nanoseconds, clamping at `u64::MAX` instead
    /// of silently truncating the `u128` returned by `Duration::as_nanos`.
    fn saturating_nanos(duration: Duration) -> u64 {
        duration.as_nanos().min(u128::from(u64::MAX)) as u64
    }

    /// Folds a new latency observation into the EWMA.
    ///
    /// The previous average is weighted by `e^(-elapsed / decay_time)` and the
    /// new sample by the remainder. `elapsed` is the time since the previous
    /// update, or since creation for the first update. A zero `decay_time` keeps
    /// no history: the EWMA becomes the new sample.
    fn update(&self, latency: Duration, decay_time: Duration) {
        let latency_ns = Self::saturating_nanos(latency);
        let now_ns = Self::saturating_nanos(self.epoch.elapsed());
        let last_update = self.last_update_ns.load(Ordering::Relaxed);
        let elapsed_ns = now_ns.saturating_sub(last_update);

        // Decay factor e^(-elapsed / decay_time). The zero case is handled
        // explicitly. When no time has elapsed since the last update it would
        // otherwise compute 0/0 = NaN, and the `as u64` cast below would then
        // collapse the EWMA to 0.
        let decay_ns = decay_time.as_nanos();
        let decay = if decay_ns == 0 {
            0.0
        } else {
            (-(elapsed_ns as f64) / (decay_ns as f64)).exp()
        };

        // EWMA update: new_ewma = old_ewma * decay + new_value * (1 - decay)
        let old_ewma = self.ewma_ns.load(Ordering::Relaxed);
        let new_ewma = ((old_ewma as f64) * decay + (latency_ns as f64) * (1.0 - decay)) as u64;

        self.ewma_ns.store(new_ewma, Ordering::Relaxed);
        self.last_latency_ns.store(latency_ns, Ordering::Relaxed);
        self.last_update_ns.store(now_ns, Ordering::Relaxed);
    }

    /// Returns the "peak" latency in nanoseconds: the larger of the EWMA and the
    /// most recent observation, so a latency spike shows up immediately.
    fn peak_latency_ns(&self) -> u64 {
        let ewma = self.ewma_ns.load(Ordering::Relaxed);
        let last = self.last_latency_ns.load(Ordering::Relaxed);
        ewma.max(last)
    }

    /// Load score used for selection (lower is better):
    /// `peak_latency * (1 + active_connections * load_penalty)`.
    fn load_score(&self, load_penalty: f64) -> f64 {
        let latency = self.peak_latency_ns() as f64;
        let active = self.active_connections.load(Ordering::Relaxed) as f64;
        latency * (1.0 + active * load_penalty)
    }

    /// Counts one more in-flight request against this target.
    fn increment_connections(&self) {
        self.active_connections.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one in-flight request as finished.
    ///
    /// Saturates at zero. A plain `fetch_sub` on an already-zero counter wraps
    /// to `u64::MAX`, which would make this target score as massively loaded
    /// after a single unmatched or duplicate release.
    fn decrement_connections(&self) {
        // `fetch_update` reports `Err` when the closure returns `None` (counter
        // already 0); leaving the value untouched is exactly what we want.
        let _ = self
            .active_connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
}
116
117/// Peak EWMA load balancer
118pub struct PeakEwmaBalancer {
119    /// Original target list
120    targets: Vec<UpstreamTarget>,
121    /// Per-target statistics
122    stats: HashMap<String, Arc<TargetStats>>,
123    /// Health status per target
124    health_status: Arc<RwLock<HashMap<String, bool>>>,
125    /// Configuration
126    config: PeakEwmaConfig,
127}
128
129impl PeakEwmaBalancer {
130    /// Create a new Peak EWMA balancer
131    pub fn new(targets: Vec<UpstreamTarget>, config: PeakEwmaConfig) -> Self {
132        let mut health_status = HashMap::new();
133        let mut stats = HashMap::new();
134
135        for target in &targets {
136            let addr = target.full_address();
137            health_status.insert(addr.clone(), true);
138            stats.insert(addr, Arc::new(TargetStats::new(config.initial_latency)));
139        }
140
141        Self {
142            targets,
143            stats,
144            health_status: Arc::new(RwLock::new(health_status)),
145            config,
146        }
147    }
148}
149
150#[async_trait]
151impl LoadBalancer for PeakEwmaBalancer {
152    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
153        trace!(
154            total_targets = self.targets.len(),
155            algorithm = "peak_ewma",
156            "Selecting upstream target"
157        );
158
159        let health = self.health_status.read().await;
160        let healthy_targets: Vec<_> = self
161            .targets
162            .iter()
163            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
164            .collect();
165        drop(health);
166
167        if healthy_targets.is_empty() {
168            warn!(
169                total_targets = self.targets.len(),
170                algorithm = "peak_ewma",
171                "No healthy upstream targets available"
172            );
173            return Err(SentinelError::NoHealthyUpstream);
174        }
175
176        // Find target with lowest load score
177        let mut best_target = None;
178        let mut best_score = f64::MAX;
179
180        for target in &healthy_targets {
181            let addr = target.full_address();
182            if let Some(stats) = self.stats.get(&addr) {
183                let score = stats.load_score(self.config.load_penalty);
184                trace!(
185                    target = %addr,
186                    score = score,
187                    ewma_ns = stats.ewma_ns.load(Ordering::Relaxed),
188                    active_connections = stats.active_connections.load(Ordering::Relaxed),
189                    "Evaluating target load score"
190                );
191                if score < best_score {
192                    best_score = score;
193                    best_target = Some(target);
194                }
195            }
196        }
197
198        let target = best_target.ok_or(SentinelError::NoHealthyUpstream)?;
199
200        // Increment active connections for selected target
201        if let Some(stats) = self.stats.get(&target.full_address()) {
202            stats.increment_connections();
203        }
204
205        trace!(
206            selected_target = %target.full_address(),
207            load_score = best_score,
208            healthy_count = healthy_targets.len(),
209            algorithm = "peak_ewma",
210            "Selected target via Peak EWMA"
211        );
212
213        Ok(TargetSelection {
214            address: target.full_address(),
215            weight: target.weight,
216            metadata: HashMap::new(),
217        })
218    }
219
220    async fn release(&self, selection: &TargetSelection) {
221        if let Some(stats) = self.stats.get(&selection.address) {
222            stats.decrement_connections();
223            trace!(
224                target = %selection.address,
225                active_connections = stats.active_connections.load(Ordering::Relaxed),
226                algorithm = "peak_ewma",
227                "Released connection"
228            );
229        }
230    }
231
232    async fn report_result(
233        &self,
234        selection: &TargetSelection,
235        success: bool,
236        latency: Option<Duration>,
237    ) {
238        // Release the connection
239        self.release(selection).await;
240
241        // Update EWMA if we have latency data
242        if let Some(latency) = latency {
243            if let Some(stats) = self.stats.get(&selection.address) {
244                stats.update(latency, self.config.decay_time);
245                trace!(
246                    target = %selection.address,
247                    latency_ms = latency.as_millis(),
248                    new_ewma_ns = stats.ewma_ns.load(Ordering::Relaxed),
249                    algorithm = "peak_ewma",
250                    "Updated EWMA latency"
251                );
252            }
253        }
254
255        // Update health if request failed
256        if !success {
257            self.report_health(&selection.address, false).await;
258        }
259    }
260
261    async fn report_result_with_latency(
262        &self,
263        address: &str,
264        success: bool,
265        latency: Option<Duration>,
266    ) {
267        // Update EWMA if we have latency data
268        if let Some(latency) = latency {
269            if let Some(stats) = self.stats.get(address) {
270                stats.update(latency, self.config.decay_time);
271                debug!(
272                    target = %address,
273                    latency_ms = latency.as_millis(),
274                    new_ewma_ns = stats.ewma_ns.load(Ordering::Relaxed),
275                    algorithm = "peak_ewma",
276                    "Updated EWMA latency via report_result_with_latency"
277                );
278            }
279        }
280
281        // Update health
282        self.report_health(address, success).await;
283    }
284
285    async fn report_health(&self, address: &str, healthy: bool) {
286        trace!(
287            target = %address,
288            healthy = healthy,
289            algorithm = "peak_ewma",
290            "Updating target health status"
291        );
292        self.health_status
293            .write()
294            .await
295            .insert(address.to_string(), healthy);
296    }
297
298    async fn healthy_targets(&self) -> Vec<String> {
299        self.health_status
300            .read()
301            .await
302            .iter()
303            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
304            .collect()
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets named `backend-{i}` on port 8080 with weight 100.
    fn make_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget::new(format!("backend-{}", i), 8080, 100))
            .collect()
    }

    #[tokio::test]
    async fn test_selects_lowest_latency() {
        let targets = make_targets(3);
        let balancer = PeakEwmaBalancer::new(targets, PeakEwmaConfig::default());

        // Simulate different latencies for each backend
        let addr0 = "backend-0:8080".to_string();
        let addr1 = "backend-1:8080".to_string();
        let addr2 = "backend-2:8080".to_string();

        // Update latencies: backend-1 has lowest
        balancer.stats.get(&addr0).unwrap().update(Duration::from_millis(100), Duration::from_secs(10));
        balancer.stats.get(&addr1).unwrap().update(Duration::from_millis(10), Duration::from_secs(10));
        balancer.stats.get(&addr2).unwrap().update(Duration::from_millis(50), Duration::from_secs(10));

        // Should select backend-1 (lowest latency)
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, addr1);
    }

    #[tokio::test]
    async fn test_considers_active_connections() {
        let targets = make_targets(2);
        let balancer = PeakEwmaBalancer::new(targets, PeakEwmaConfig::default());

        let addr0 = "backend-0:8080".to_string();
        let addr1 = "backend-1:8080".to_string();

        // Same latency, but backend-0 has active connections
        balancer.stats.get(&addr0).unwrap().update(Duration::from_millis(10), Duration::from_secs(10));
        balancer.stats.get(&addr1).unwrap().update(Duration::from_millis(10), Duration::from_secs(10));

        // Add active connections to backend-0
        for _ in 0..5 {
            balancer.stats.get(&addr0).unwrap().increment_connections();
        }

        // Should select backend-1 (no active connections)
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, addr1);
    }

    #[tokio::test]
    async fn test_ewma_decay() {
        let targets = make_targets(1);
        let config = PeakEwmaConfig {
            decay_time: Duration::from_millis(100),
            initial_latency: Duration::from_millis(50), // Start with 50ms
            load_penalty: 1.5,
        };
        let balancer = PeakEwmaBalancer::new(targets, config);

        let addr = "backend-0:8080".to_string();
        let stats = balancer.stats.get(&addr).unwrap();

        let initial_latency_ns = Duration::from_millis(50).as_nanos() as u64;
        let low_latency_ns = Duration::from_millis(10).as_nanos() as u64;
        let high_latency_ns = Duration::from_millis(100).as_nanos() as u64;

        // Wait a bit so the first update has some elapsed time
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Update with high latency
        stats.update(Duration::from_millis(100), Duration::from_millis(100));
        let after_high = stats.ewma_ns.load(Ordering::Relaxed);

        // The high sample pulls the EWMA up from its 50ms seed, but a single
        // sample cannot push it past the sample itself.
        assert!(
            after_high > initial_latency_ns && after_high <= high_latency_ns,
            "EWMA after high update ({}) should be in ({}, {}]",
            after_high,
            initial_latency_ns,
            high_latency_ns
        );

        // Wait for decay and update with low latency
        tokio::time::sleep(Duration::from_millis(200)).await;
        stats.update(Duration::from_millis(10), Duration::from_millis(100));
        let after_low = stats.ewma_ns.load(Ordering::Relaxed);

        // After the low latency update (with significant decay time),
        // the EWMA should move toward the low value
        // decay = e^(-200/100) = e^(-2) ≈ 0.135
        // new_ewma ≈ old * 0.135 + 10ms * 0.865 ≈ mostly 10ms
        assert!(
            after_low < after_high,
            "EWMA after low update ({}) should drop below the post-spike EWMA ({})",
            after_low,
            after_high
        );
        assert!(
            after_low > low_latency_ns,
            "EWMA after low update ({}) should be greater than low latency ({}) due to some carry-over",
            after_low,
            low_latency_ns
        );
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let targets = make_targets(1);
        let balancer = PeakEwmaBalancer::new(targets, PeakEwmaConfig::default());

        // Select increments connections
        let selection = balancer.select(None).await.unwrap();
        let stats = balancer.stats.get(&selection.address).unwrap();
        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 1);

        // Release decrements connections
        balancer.release(&selection).await;
        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_skips_unhealthy_targets() {
        let balancer = PeakEwmaBalancer::new(make_targets(2), PeakEwmaConfig::default());
        let addr0 = "backend-0:8080".to_string();
        let addr1 = "backend-1:8080".to_string();

        // Both targets start with identical stats, so backend-0 would win the
        // tie if it were still healthy.
        balancer.report_health(&addr0, false).await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, addr1);
    }

    #[tokio::test]
    async fn test_errors_when_no_healthy_targets() {
        let balancer = PeakEwmaBalancer::new(make_targets(1), PeakEwmaConfig::default());
        balancer.report_health("backend-0:8080", false).await;

        assert!(matches!(
            balancer.select(None).await,
            Err(SentinelError::NoHealthyUpstream)
        ));
    }

    #[tokio::test]
    async fn test_report_result_releases_and_records_latency() {
        let balancer = PeakEwmaBalancer::new(make_targets(1), PeakEwmaConfig::default());

        let selection = balancer.select(None).await.unwrap();
        let stats = balancer.stats.get(&selection.address).unwrap();
        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 1);

        balancer
            .report_result(&selection, true, Some(Duration::from_millis(25)))
            .await;

        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 0);
        assert_eq!(
            stats.last_latency_ns.load(Ordering::Relaxed),
            Duration::from_millis(25).as_nanos() as u64
        );
        // A successful result leaves the target healthy.
        assert_eq!(balancer.healthy_targets().await, vec![selection.address.clone()]);
    }
}