1use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::sync::RwLock;
20use tracing::{debug, trace, warn};
21
22use sentinel_common::errors::{SentinelError, SentinelResult};
23
24use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
25
/// Tuning parameters for the Peak-EWMA load balancer.
#[derive(Debug, Clone)]
pub struct PeakEwmaConfig {
    /// Time constant of the exponential decay applied to latency samples.
    ///
    /// When a new sample arrives, the previous average keeps weight
    /// `exp(-elapsed / decay_time)`, where `elapsed` is the time since the previous sample.
    pub decay_time: Duration,
    /// Latency assumed for every target before any sample has been recorded.
    pub initial_latency: Duration,
    /// Per-connection multiplier used in the load score:
    /// `score = peak_latency * (1 + active_connections * load_penalty)`.
    pub load_penalty: f64,
}

impl Default for PeakEwmaConfig {
    /// Returns a 10 s decay time, a 1 ms initial latency and a load penalty of 1.5.
    fn default() -> Self {
        let decay_time = Duration::from_secs(10);
        let initial_latency = Duration::from_millis(1);
        let load_penalty = 1.5;

        PeakEwmaConfig {
            decay_time,
            initial_latency,
            load_penalty,
        }
    }
}
48
/// Lock-free latency and load bookkeeping for a single upstream target.
///
/// Every field is an independent atomic accessed with `Ordering::Relaxed`, so readers may
/// observe a mix of values written by concurrent updates. The values only feed a heuristic
/// load score.
struct TargetStats {
    /// Exponentially weighted moving average of observed latency, in nanoseconds.
    ewma_ns: AtomicU64,
    /// Most recent latency sample, in nanoseconds.
    last_latency_ns: AtomicU64,
    /// When `update` last ran, in nanoseconds since `epoch` (0 means "at creation").
    last_update_ns: AtomicU64,
    /// In-flight connections attributed to this target; never goes below zero.
    active_connections: AtomicU64,
    /// Reference instant for `last_update_ns`, captured at construction.
    epoch: Instant,
}

impl TargetStats {
    /// Creates stats seeded with `initial_latency` as both the EWMA and the last sample.
    fn new(initial_latency: Duration) -> Self {
        let initial_ns = Self::saturating_nanos(initial_latency);
        Self {
            ewma_ns: AtomicU64::new(initial_ns),
            last_latency_ns: AtomicU64::new(initial_ns),
            last_update_ns: AtomicU64::new(0),
            active_connections: AtomicU64::new(0),
            epoch: Instant::now(),
        }
    }

    /// Converts a duration to whole nanoseconds, clamping at `u64::MAX`.
    ///
    /// A plain `as_nanos() as u64` would silently drop the high bits of very large durations.
    fn saturating_nanos(duration: Duration) -> u64 {
        duration.as_nanos().min(u128::from(u64::MAX)) as u64
    }

    /// Folds a latency sample into the EWMA and records it as the latest sample.
    ///
    /// The previous average keeps weight `exp(-elapsed / decay_time)`. Here `elapsed` is the
    /// time since the previous call, or since construction for the first call. The longer
    /// the gap between samples, the less the old average counts.
    ///
    /// The load/compute/store sequence is not atomic as a whole: concurrent calls may
    /// overwrite each other's result and lose a sample.
    fn update(&self, latency: Duration, decay_time: Duration) {
        let latency_ns = Self::saturating_nanos(latency);
        let now_ns = Self::saturating_nanos(self.epoch.elapsed());
        let last_update = self.last_update_ns.load(Ordering::Relaxed);
        let elapsed_ns = now_ns.saturating_sub(last_update);

        // A zero decay time means "keep no history". It is handled explicitly because with
        // `elapsed_ns == 0` the division would be 0.0 / 0.0 = NaN, and a NaN EWMA is cast to 0.
        let decay = if decay_time.is_zero() {
            0.0
        } else {
            (-(elapsed_ns as f64 / decay_time.as_nanos() as f64)).exp()
        };

        let old_ewma = self.ewma_ns.load(Ordering::Relaxed) as f64;
        let new_ewma = (old_ewma * decay + latency_ns as f64 * (1.0 - decay)) as u64;

        self.ewma_ns.store(new_ewma, Ordering::Relaxed);
        self.last_latency_ns.store(latency_ns, Ordering::Relaxed);
        self.last_update_ns.store(now_ns, Ordering::Relaxed);
    }

    /// Returns the larger of the EWMA and the most recent sample.
    ///
    /// A slow sample is reflected immediately, and it keeps dominating until the next
    /// `update` replaces it.
    fn peak_latency_ns(&self) -> u64 {
        let ewma = self.ewma_ns.load(Ordering::Relaxed);
        let last = self.last_latency_ns.load(Ordering::Relaxed);
        ewma.max(last)
    }

    /// Load score used for target selection (lower is better):
    /// `peak_latency * (1 + active_connections * load_penalty)`.
    fn load_score(&self, load_penalty: f64) -> f64 {
        let latency = self.peak_latency_ns() as f64;
        let active = self.active_connections.load(Ordering::Relaxed) as f64;
        latency * (1.0 + active * load_penalty)
    }

    /// Counts one more in-flight connection.
    fn increment_connections(&self) {
        self.active_connections.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one fewer in-flight connection, saturating at zero.
    ///
    /// An unmatched decrement (e.g. the same selection released twice) must not wrap the
    /// counter to `u64::MAX`. A wrapped counter would make this target's load score
    /// enormous and effectively remove it from rotation.
    fn decrement_connections(&self) {
        // `Err(0)` just means the counter was already zero; nothing to undo.
        let _ = self
            .active_connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
}
116
117pub struct PeakEwmaBalancer {
119 targets: Vec<UpstreamTarget>,
121 stats: HashMap<String, Arc<TargetStats>>,
123 health_status: Arc<RwLock<HashMap<String, bool>>>,
125 config: PeakEwmaConfig,
127}
128
129impl PeakEwmaBalancer {
130 pub fn new(targets: Vec<UpstreamTarget>, config: PeakEwmaConfig) -> Self {
132 let mut health_status = HashMap::new();
133 let mut stats = HashMap::new();
134
135 for target in &targets {
136 let addr = target.full_address();
137 health_status.insert(addr.clone(), true);
138 stats.insert(addr, Arc::new(TargetStats::new(config.initial_latency)));
139 }
140
141 Self {
142 targets,
143 stats,
144 health_status: Arc::new(RwLock::new(health_status)),
145 config,
146 }
147 }
148}
149
150#[async_trait]
151impl LoadBalancer for PeakEwmaBalancer {
152 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
153 trace!(
154 total_targets = self.targets.len(),
155 algorithm = "peak_ewma",
156 "Selecting upstream target"
157 );
158
159 let health = self.health_status.read().await;
160 let healthy_targets: Vec<_> = self
161 .targets
162 .iter()
163 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
164 .collect();
165 drop(health);
166
167 if healthy_targets.is_empty() {
168 warn!(
169 total_targets = self.targets.len(),
170 algorithm = "peak_ewma",
171 "No healthy upstream targets available"
172 );
173 return Err(SentinelError::NoHealthyUpstream);
174 }
175
176 let mut best_target = None;
178 let mut best_score = f64::MAX;
179
180 for target in &healthy_targets {
181 let addr = target.full_address();
182 if let Some(stats) = self.stats.get(&addr) {
183 let score = stats.load_score(self.config.load_penalty);
184 trace!(
185 target = %addr,
186 score = score,
187 ewma_ns = stats.ewma_ns.load(Ordering::Relaxed),
188 active_connections = stats.active_connections.load(Ordering::Relaxed),
189 "Evaluating target load score"
190 );
191 if score < best_score {
192 best_score = score;
193 best_target = Some(target);
194 }
195 }
196 }
197
198 let target = best_target.ok_or(SentinelError::NoHealthyUpstream)?;
199
200 if let Some(stats) = self.stats.get(&target.full_address()) {
202 stats.increment_connections();
203 }
204
205 trace!(
206 selected_target = %target.full_address(),
207 load_score = best_score,
208 healthy_count = healthy_targets.len(),
209 algorithm = "peak_ewma",
210 "Selected target via Peak EWMA"
211 );
212
213 Ok(TargetSelection {
214 address: target.full_address(),
215 weight: target.weight,
216 metadata: HashMap::new(),
217 })
218 }
219
220 async fn release(&self, selection: &TargetSelection) {
221 if let Some(stats) = self.stats.get(&selection.address) {
222 stats.decrement_connections();
223 trace!(
224 target = %selection.address,
225 active_connections = stats.active_connections.load(Ordering::Relaxed),
226 algorithm = "peak_ewma",
227 "Released connection"
228 );
229 }
230 }
231
232 async fn report_result(
233 &self,
234 selection: &TargetSelection,
235 success: bool,
236 latency: Option<Duration>,
237 ) {
238 self.release(selection).await;
240
241 if let Some(latency) = latency {
243 if let Some(stats) = self.stats.get(&selection.address) {
244 stats.update(latency, self.config.decay_time);
245 trace!(
246 target = %selection.address,
247 latency_ms = latency.as_millis(),
248 new_ewma_ns = stats.ewma_ns.load(Ordering::Relaxed),
249 algorithm = "peak_ewma",
250 "Updated EWMA latency"
251 );
252 }
253 }
254
255 if !success {
257 self.report_health(&selection.address, false).await;
258 }
259 }
260
261 async fn report_result_with_latency(
262 &self,
263 address: &str,
264 success: bool,
265 latency: Option<Duration>,
266 ) {
267 if let Some(latency) = latency {
269 if let Some(stats) = self.stats.get(address) {
270 stats.update(latency, self.config.decay_time);
271 debug!(
272 target = %address,
273 latency_ms = latency.as_millis(),
274 new_ewma_ns = stats.ewma_ns.load(Ordering::Relaxed),
275 algorithm = "peak_ewma",
276 "Updated EWMA latency via report_result_with_latency"
277 );
278 }
279 }
280
281 self.report_health(address, success).await;
283 }
284
285 async fn report_health(&self, address: &str, healthy: bool) {
286 trace!(
287 target = %address,
288 healthy = healthy,
289 algorithm = "peak_ewma",
290 "Updating target health status"
291 );
292 self.health_status
293 .write()
294 .await
295 .insert(address.to_string(), healthy);
296 }
297
298 async fn healthy_targets(&self) -> Vec<String> {
299 self.health_status
300 .read()
301 .await
302 .iter()
303 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
304 .collect()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets named `backend-{i}` on port 8080 with weight 100.
    fn make_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget::new(format!("backend-{}", i), 8080, 100))
            .collect()
    }

    #[tokio::test]
    async fn test_selects_lowest_latency() {
        let targets = make_targets(3);
        let balancer = PeakEwmaBalancer::new(targets, PeakEwmaConfig::default());

        let addr0 = "backend-0:8080".to_string();
        let addr1 = "backend-1:8080".to_string();
        let addr2 = "backend-2:8080".to_string();

        balancer.stats.get(&addr0).unwrap().update(Duration::from_millis(100), Duration::from_secs(10));
        balancer.stats.get(&addr1).unwrap().update(Duration::from_millis(10), Duration::from_secs(10));
        balancer.stats.get(&addr2).unwrap().update(Duration::from_millis(50), Duration::from_secs(10));

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, addr1);
    }

    #[tokio::test]
    async fn test_considers_active_connections() {
        let targets = make_targets(2);
        let balancer = PeakEwmaBalancer::new(targets, PeakEwmaConfig::default());

        let addr0 = "backend-0:8080".to_string();
        let addr1 = "backend-1:8080".to_string();

        balancer.stats.get(&addr0).unwrap().update(Duration::from_millis(10), Duration::from_secs(10));
        balancer.stats.get(&addr1).unwrap().update(Duration::from_millis(10), Duration::from_secs(10));

        // Equal latency, so the in-flight connections on backend-0 must tip the choice.
        for _ in 0..5 {
            balancer.stats.get(&addr0).unwrap().increment_connections();
        }

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, addr1);
    }

    #[tokio::test]
    async fn test_ewma_decay() {
        let targets = make_targets(1);
        let config = PeakEwmaConfig {
            decay_time: Duration::from_millis(100),
            initial_latency: Duration::from_millis(50),
            load_penalty: 1.5,
        };
        let balancer = PeakEwmaBalancer::new(targets, config);

        let addr = "backend-0:8080".to_string();
        let stats = balancer.stats.get(&addr).unwrap();

        let initial_ns = Duration::from_millis(50).as_nanos() as u64;
        let low_latency_ns = Duration::from_millis(10).as_nanos() as u64;
        let high_latency_ns = Duration::from_millis(100).as_nanos() as u64;

        // Half a decay window passes, so the high sample only partly replaces the initial value.
        tokio::time::sleep(Duration::from_millis(50)).await;
        stats.update(Duration::from_millis(100), Duration::from_millis(100));
        let after_high = stats.ewma_ns.load(Ordering::Relaxed);

        assert!(
            after_high > initial_ns,
            "EWMA after high update ({}) should exceed the initial latency ({})",
            after_high,
            initial_ns
        );

        // Two decay windows pass, so the low sample dominates but some history remains.
        tokio::time::sleep(Duration::from_millis(200)).await;
        stats.update(Duration::from_millis(10), Duration::from_millis(100));
        let after_low = stats.ewma_ns.load(Ordering::Relaxed);

        assert!(
            after_low < after_high,
            "EWMA after low update ({}) should drop below EWMA after high update ({})",
            after_low,
            after_high
        );
        assert!(
            after_low < high_latency_ns,
            "EWMA after low update ({}) should be less than high latency ({})",
            after_low,
            high_latency_ns
        );
        assert!(
            after_low > low_latency_ns,
            "EWMA after low update ({}) should be greater than low latency ({}) due to some carry-over",
            after_low,
            low_latency_ns
        );
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let targets = make_targets(1);
        let balancer = PeakEwmaBalancer::new(targets, PeakEwmaConfig::default());

        let selection = balancer.select(None).await.unwrap();
        let stats = balancer.stats.get(&selection.address).unwrap();
        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 1);

        balancer.release(&selection).await;
        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_skips_unhealthy_targets() {
        let balancer = PeakEwmaBalancer::new(make_targets(2), PeakEwmaConfig::default());

        let addr0 = "backend-0:8080".to_string();
        let addr1 = "backend-1:8080".to_string();

        // backend-0 is faster, but once it is marked unhealthy it must not be chosen.
        balancer.stats.get(&addr0).unwrap().update(Duration::from_millis(1), Duration::from_secs(10));
        balancer.stats.get(&addr1).unwrap().update(Duration::from_millis(100), Duration::from_secs(10));
        balancer.report_health(&addr0, false).await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, addr1);
        assert_eq!(balancer.healthy_targets().await, vec![addr1.clone()]);
    }

    #[tokio::test]
    async fn test_no_healthy_upstream() {
        let balancer = PeakEwmaBalancer::new(make_targets(2), PeakEwmaConfig::default());
        balancer.report_health("backend-0:8080", false).await;
        balancer.report_health("backend-1:8080", false).await;

        assert!(matches!(
            balancer.select(None).await,
            Err(SentinelError::NoHealthyUpstream)
        ));
        assert!(balancer.healthy_targets().await.is_empty());
    }

    #[tokio::test]
    async fn test_report_result_releases_and_records_latency() {
        let balancer = PeakEwmaBalancer::new(make_targets(1), PeakEwmaConfig::default());

        let selection = balancer.select(None).await.unwrap();
        balancer
            .report_result(&selection, true, Some(Duration::from_millis(40)))
            .await;

        let stats = balancer.stats.get(&selection.address).unwrap();
        assert_eq!(stats.active_connections.load(Ordering::Relaxed), 0);
        assert_eq!(
            stats.last_latency_ns.load(Ordering::Relaxed),
            Duration::from_millis(40).as_nanos() as u64
        );
        // A successful result leaves the target healthy.
        assert_eq!(balancer.healthy_targets().await, vec![selection.address.clone()]);
    }
}