// sentinel_proxy/upstream/p2c.rs

1use async_trait::async_trait;
2use rand::rngs::StdRng;
3use rand::{Rng, SeedableRng};
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::RwLock;
9
10use tracing::{debug, info};
11
12use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
13use sentinel_common::errors::{SentinelError, SentinelResult};
/// Load metric the P2C balancer uses to compare sampled candidates.
///
/// Smaller values mean "less loaded": the balancer routes to the candidate
/// with the lowest value under the configured metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadMetric {
    /// Number of connections currently counted against the target.
    Connections,
    /// Mean of the most recently recorded latency samples (in microseconds).
    Latency,
    /// Additive blend: `connections + avg_latency_us / 100`.
    ///
    /// The latency term is scaled so that ~10ms of average latency weighs
    /// about as much as 100 open connections.
    Combined,
    /// Last reported CPU usage percentage (0-100); must be fed in from
    /// external monitoring via `P2cBalancer::update_metrics`.
    CpuUsage,
    /// Requests per second: cumulative request count divided by the time
    /// elapsed since the target's metrics were last updated.
    RequestRate,
}

impl Default for LoadMetric {
    /// Connection counting is the default metric.
    fn default() -> Self {
        LoadMetric::Connections
    }
}

36/// Configuration for P2C load balancer
37#[derive(Debug, Clone)]
38pub struct P2cConfig {
39    /// Load metric to use for selection
40    pub load_metric: LoadMetric,
41    /// Weight multiplier for secondary metric in combined mode
42    pub secondary_weight: f64,
43    /// Whether to use weighted random selection
44    pub use_weights: bool,
45    /// Latency window for averaging (in seconds)
46    pub latency_window_secs: u64,
47    /// Enable power of three choices for better distribution
48    pub power_of_three: bool,
49}
50
51impl Default for P2cConfig {
52    fn default() -> Self {
53        Self {
54            load_metric: LoadMetric::Connections,
55            secondary_weight: 0.5,
56            use_weights: true,
57            latency_window_secs: 10,
58            power_of_three: false,
59        }
60    }
61}
62
63/// Target metrics for load calculation
64#[derive(Debug, Clone)]
65struct TargetMetrics {
66    /// Active connections
67    connections: Arc<AtomicU64>,
68    /// Total requests
69    requests: Arc<AtomicU64>,
70    /// Total latency in microseconds
71    total_latency_us: Arc<AtomicU64>,
72    /// Request count for latency averaging
73    latency_count: Arc<AtomicU64>,
74    /// CPU usage percentage (0-100)
75    cpu_usage: Arc<AtomicU64>,
76    /// Last update time
77    last_update: Arc<RwLock<Instant>>,
78    /// Recent latency measurements (ring buffer)
79    recent_latencies: Arc<RwLock<Vec<Duration>>>,
80    /// Ring buffer position
81    latency_buffer_pos: Arc<AtomicUsize>,
82}
83
84impl TargetMetrics {
85    fn new(buffer_size: usize) -> Self {
86        Self {
87            connections: Arc::new(AtomicU64::new(0)),
88            requests: Arc::new(AtomicU64::new(0)),
89            total_latency_us: Arc::new(AtomicU64::new(0)),
90            latency_count: Arc::new(AtomicU64::new(0)),
91            cpu_usage: Arc::new(AtomicU64::new(0)),
92            last_update: Arc::new(RwLock::new(Instant::now())),
93            recent_latencies: Arc::new(RwLock::new(vec![Duration::ZERO; buffer_size])),
94            latency_buffer_pos: Arc::new(AtomicUsize::new(0)),
95        }
96    }
97
98    /// Calculate average latency over the window
99    async fn average_latency(&self) -> Duration {
100        let latencies = self.recent_latencies.read().await;
101        let count = self.latency_count.load(Ordering::Relaxed);
102
103        if count == 0 {
104            return Duration::ZERO;
105        }
106
107        let total: Duration = latencies.iter().sum();
108        let sample_count = count.min(latencies.len() as u64);
109
110        if sample_count > 0 {
111            total / sample_count as u32
112        } else {
113            Duration::ZERO
114        }
115    }
116
117    /// Record a latency measurement
118    async fn record_latency(&self, latency: Duration) {
119        let pos = self.latency_buffer_pos.fetch_add(1, Ordering::Relaxed);
120        let mut latencies = self.recent_latencies.write().await;
121        let buffer_size = latencies.len();
122        latencies[pos % buffer_size] = latency;
123
124        self.total_latency_us
125            .fetch_add(latency.as_micros() as u64, Ordering::Relaxed);
126        self.latency_count.fetch_add(1, Ordering::Relaxed);
127    }
128
129    /// Get current load based on metric type
130    async fn get_load(&self, metric: LoadMetric) -> f64 {
131        match metric {
132            LoadMetric::Connections => self.connections.load(Ordering::Relaxed) as f64,
133            LoadMetric::Latency => self.average_latency().await.as_micros() as f64,
134            LoadMetric::Combined => {
135                let connections = self.connections.load(Ordering::Relaxed) as f64;
136                let latency = self.average_latency().await.as_micros() as f64;
137                // Normalize latency to be on similar scale as connections
138                // (assuming avg latency ~10ms = 10000us, and avg connections ~100)
139                connections + (latency / 100.0)
140            }
141            LoadMetric::CpuUsage => self.cpu_usage.load(Ordering::Relaxed) as f64,
142            LoadMetric::RequestRate => {
143                // Calculate requests per second over the last update interval
144                let requests = self.requests.load(Ordering::Relaxed);
145                let last_update = *self.last_update.read().await;
146                let elapsed = last_update.elapsed().as_secs_f64();
147                if elapsed > 0.0 {
148                    requests as f64 / elapsed
149                } else {
150                    0.0
151                }
152            }
153        }
154    }
155}
156
157/// Power of Two Choices load balancer
158pub struct P2cBalancer {
159    /// Configuration
160    config: P2cConfig,
161    /// All upstream targets
162    targets: Vec<UpstreamTarget>,
163    /// Target health status
164    health_status: Arc<RwLock<HashMap<String, bool>>>,
165    /// Metrics per target
166    metrics: Vec<TargetMetrics>,
167    /// Random number generator (thread-safe)
168    rng: Arc<RwLock<StdRng>>,
169    /// Cumulative weights for weighted selection
170    cumulative_weights: Vec<u32>,
171}
172
173impl P2cBalancer {
174    pub fn new(targets: Vec<UpstreamTarget>, config: P2cConfig) -> Self {
175        let buffer_size = (config.latency_window_secs * 100) as usize; // 100 samples/sec
176        let metrics = targets
177            .iter()
178            .map(|_| TargetMetrics::new(buffer_size))
179            .collect();
180
181        // Calculate cumulative weights for weighted random selection
182        let mut cumulative_weights = Vec::with_capacity(targets.len());
183        let mut cumsum = 0u32;
184        for target in &targets {
185            cumsum += target.weight;
186            cumulative_weights.push(cumsum);
187        }
188
189        Self {
190            config,
191            targets,
192            health_status: Arc::new(RwLock::new(HashMap::new())),
193            metrics,
194            rng: Arc::new(RwLock::new(StdRng::from_entropy())),
195            cumulative_weights,
196        }
197    }
198
199    /// Select a random healthy target index
200    async fn random_healthy_target(&self) -> Option<usize> {
201        let health = self.health_status.read().await;
202        let healthy_indices: Vec<usize> = self
203            .targets
204            .iter()
205            .enumerate()
206            .filter_map(|(i, t)| {
207                let target_id = format!("{}:{}", t.address, t.port);
208                if health.get(&target_id).copied().unwrap_or(true) {
209                    Some(i)
210                } else {
211                    None
212                }
213            })
214            .collect();
215
216        if healthy_indices.is_empty() {
217            return None;
218        }
219
220        let mut rng = self.rng.write().await;
221
222        if self.config.use_weights && !self.cumulative_weights.is_empty() {
223            // Weighted random selection
224            let total_weight = self.cumulative_weights.last().copied().unwrap_or(0);
225            if total_weight > 0 {
226                let threshold = rng.gen_range(0..total_weight);
227                for &idx in &healthy_indices {
228                    if self.cumulative_weights[idx] > threshold {
229                        return Some(idx);
230                    }
231                }
232            }
233        }
234
235        // Fallback to uniform random
236        Some(healthy_indices[rng.gen_range(0..healthy_indices.len())])
237    }
238
239    /// Select the least loaded target from candidates
240    async fn select_least_loaded(&self, candidates: Vec<usize>) -> Option<usize> {
241        if candidates.is_empty() {
242            return None;
243        }
244
245        let mut min_load = f64::MAX;
246        let mut best_target = candidates[0];
247
248        for &idx in &candidates {
249            let load = self.metrics[idx].get_load(self.config.load_metric).await;
250
251            if load < min_load {
252                min_load = load;
253                best_target = idx;
254            }
255        }
256
257        debug!(
258            "P2C selected target {} with load {:.2} from {} candidates",
259            best_target,
260            min_load,
261            candidates.len()
262        );
263
264        Some(best_target)
265    }
266
267    /// Track connection acquisition
268    pub fn acquire_connection(&self, target_index: usize) {
269        self.metrics[target_index]
270            .connections
271            .fetch_add(1, Ordering::Relaxed);
272        self.metrics[target_index]
273            .requests
274            .fetch_add(1, Ordering::Relaxed);
275    }
276
277    /// Track connection release
278    pub fn release_connection(&self, target_index: usize) {
279        self.metrics[target_index]
280            .connections
281            .fetch_sub(1, Ordering::Relaxed);
282    }
283
284    /// Update target metrics
285    pub async fn update_metrics(
286        &self,
287        target_index: usize,
288        latency: Option<Duration>,
289        cpu_usage: Option<u8>,
290    ) {
291        if let Some(latency) = latency {
292            self.metrics[target_index].record_latency(latency).await;
293        }
294
295        if let Some(cpu) = cpu_usage {
296            self.metrics[target_index]
297                .cpu_usage
298                .store(cpu as u64, Ordering::Relaxed);
299        }
300
301        *self.metrics[target_index].last_update.write().await = Instant::now();
302    }
303}
304
305#[async_trait]
306impl LoadBalancer for P2cBalancer {
307    async fn select(
308        &self,
309        _context: Option<&RequestContext>,
310    ) -> SentinelResult<TargetSelection> {
311        // Select candidates
312        let num_choices = if self.config.power_of_three { 3 } else { 2 };
313        let mut candidates = Vec::with_capacity(num_choices);
314
315        for _ in 0..num_choices {
316            if let Some(idx) = self.random_healthy_target().await {
317                if !candidates.contains(&idx) {
318                    candidates.push(idx);
319                }
320            }
321        }
322
323        if candidates.is_empty() {
324            return Err(SentinelError::NoHealthyUpstream);
325        }
326
327        // Select least loaded from candidates
328        let target_index = self
329            .select_least_loaded(candidates)
330            .await
331            .ok_or(SentinelError::NoHealthyUpstream)?;
332
333        let target = &self.targets[target_index];
334
335        // Track connection
336        self.acquire_connection(target_index);
337
338        // Get current metrics for metadata
339        let current_load = self.metrics[target_index]
340            .get_load(self.config.load_metric)
341            .await;
342        let connections = self.metrics[target_index]
343            .connections
344            .load(Ordering::Relaxed);
345        let avg_latency = self.metrics[target_index].average_latency().await;
346
347        Ok(TargetSelection {
348            address: format!("{}:{}", target.address, target.port),
349            weight: target.weight,
350            metadata: {
351                let mut meta = HashMap::new();
352                meta.insert("algorithm".to_string(), "p2c".to_string());
353                meta.insert("target_index".to_string(), target_index.to_string());
354                meta.insert("current_load".to_string(), format!("{:.2}", current_load));
355                meta.insert("connections".to_string(), connections.to_string());
356                meta.insert(
357                    "avg_latency_ms".to_string(),
358                    format!("{:.2}", avg_latency.as_millis()),
359                );
360                meta.insert(
361                    "metric_type".to_string(),
362                    format!("{:?}", self.config.load_metric),
363                );
364                meta
365            },
366        })
367    }
368
369    async fn report_health(&self, address: &str, healthy: bool) {
370        let mut health = self.health_status.write().await;
371        let previous = health.insert(address.to_string(), healthy);
372
373        if previous != Some(healthy) {
374            info!(
375                "P2C: Target {} health changed from {:?} to {}",
376                address, previous, healthy
377            );
378        }
379    }
380
381    async fn healthy_targets(&self) -> Vec<String> {
382        let health = self.health_status.read().await;
383        self.targets
384            .iter()
385            .filter_map(|t| {
386                let target_id = format!("{}:{}", t.address, t.port);
387                if health.get(&target_id).copied().unwrap_or(true) {
388                    Some(target_id)
389                } else {
390                    None
391                }
392            })
393            .collect()
394    }
395
396    async fn release(&self, selection: &TargetSelection) {
397        if let Some(index_str) = selection.metadata.get("target_index") {
398            if let Ok(index) = index_str.parse::<usize>() {
399                self.release_connection(index);
400            }
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets `10.0.0.1:8080`, `10.0.0.2:8080`, ... with weight 100.
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    #[tokio::test]
    async fn test_p2c_selection() {
        let targets = create_test_targets(5);
        let config = P2cConfig::default();
        let balancer = P2cBalancer::new(targets.clone(), config);

        // Simulate different loads
        balancer.metrics[0].connections.store(10, Ordering::Relaxed);
        balancer.metrics[1].connections.store(5, Ordering::Relaxed);
        balancer.metrics[2].connections.store(15, Ordering::Relaxed);
        balancer.metrics[3].connections.store(3, Ordering::Relaxed);
        balancer.metrics[4].connections.store(8, Ordering::Relaxed);

        // Run selections and verify distribution
        let mut selections = vec![0usize; 5];
        for _ in 0..1000 {
            if let Ok(selection) = balancer.select(None).await {
                if let Some(idx_str) = selection.metadata.get("target_index") {
                    if let Ok(idx) = idx_str.parse::<usize>() {
                        selections[idx] += 1;

                        // Simulate connection release
                        balancer.release(&selection).await;
                    }
                }
            }
        }

        // Target 3 (load=3) should get more than target 2 (load=15)
        assert!(selections[3] > selections[2]);

        // All targets should get some traffic (duplicate draws let even the
        // most loaded target win occasionally)
        for count in selections {
            assert!(count > 0, "All targets should receive some traffic");
        }
    }

    #[tokio::test]
    async fn test_p2c_with_latency_metric() {
        let targets = create_test_targets(3);
        let config = P2cConfig {
            load_metric: LoadMetric::Latency,
            ..Default::default()
        };
        let balancer = P2cBalancer::new(targets.clone(), config);

        // Set different latencies
        balancer
            .update_metrics(0, Some(Duration::from_millis(100)), None)
            .await;
        balancer
            .update_metrics(1, Some(Duration::from_millis(10)), None)
            .await;
        balancer
            .update_metrics(2, Some(Duration::from_millis(50)), None)
            .await;

        let selection = balancer.select(None).await.unwrap();
        assert!(selection.metadata.contains_key("avg_latency_ms"));
        balancer.release(&selection).await;

        // The slowest target only wins when every sample lands on it, while the
        // fastest wins whenever it is sampled, so it must be chosen far more often.
        let mut selections = vec![0usize; 3];
        for _ in 0..300 {
            let selection = balancer.select(None).await.unwrap();
            let idx: usize = selection.metadata["target_index"].parse().unwrap();
            selections[idx] += 1;
            balancer.release(&selection).await;
        }
        assert!(
            selections[1] > selections[0],
            "lower-latency target should be preferred: {:?}",
            selections
        );
    }

    #[tokio::test]
    async fn test_p2c_power_of_three() {
        let targets = create_test_targets(10);
        let config = P2cConfig {
            power_of_three: true,
            ..Default::default()
        };
        let balancer = P2cBalancer::new(targets.clone(), config);

        // Set varied loads
        for i in 0..10 {
            balancer.metrics[i]
                .connections
                .store((i * 2) as u64, Ordering::Relaxed);
        }

        let mut low_load_selections = 0;
        for _ in 0..100 {
            if let Ok(selection) = balancer.select(None).await {
                if let Some(idx_str) = selection.metadata.get("target_index") {
                    if let Ok(idx) = idx_str.parse::<usize>() {
                        if idx < 3 {
                            // Low load targets
                            low_load_selections += 1;
                        }
                        balancer.release(&selection).await;
                    }
                }
            }
        }

        // Power of three should give even better selection of low-load targets
        assert!(
            low_load_selections > 60,
            "P3C should favor low-load targets more"
        );
    }

    #[tokio::test]
    async fn test_weighted_selection() {
        let mut targets = create_test_targets(3);
        targets[0].weight = 100;
        targets[1].weight = 200; // Double weight
        targets[2].weight = 100;

        let config = P2cConfig {
            use_weights: true,
            ..Default::default()
        };
        let balancer = P2cBalancer::new(targets.clone(), config);

        // Equal loads - weight should influence selection
        for i in 0..3 {
            balancer.metrics[i].connections.store(5, Ordering::Relaxed);
        }

        let mut selections = vec![0usize; 3];
        for _ in 0..1000 {
            if let Some(idx) = balancer.random_healthy_target().await {
                selections[idx] += 1;
            }
        }

        // Guard the ratio below against a division by zero.
        assert!(selections[0] > 0, "target 0 was never selected");

        // Target 1 should get roughly twice the traffic due to weight
        let ratio = selections[1] as f64 / selections[0] as f64;
        assert!(
            ratio > 1.5 && ratio < 2.5,
            "Weighted selection not working properly"
        );
    }
}