sentinel_proxy/upstream/
locality.rs

1//! Locality-aware load balancer
2//!
3//! Prefers targets in the same zone/region as the proxy, falling back to
4//! other zones when local targets are unhealthy or overloaded. Useful for
5//! multi-region deployments to minimize latency and cross-zone traffic costs.
6
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tracing::{debug, trace, warn};
13
14use sentinel_common::errors::{SentinelError, SentinelResult};
15
16use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
17
18/// Configuration for locality-aware load balancing
19#[derive(Debug, Clone)]
20pub struct LocalityAwareConfig {
21    /// The local zone/region identifier for this proxy instance
22    pub local_zone: String,
23    /// Fallback strategy when no local targets are healthy
24    pub fallback_strategy: LocalityFallback,
25    /// Minimum healthy local targets before considering fallback
26    pub min_local_healthy: usize,
27    /// Whether to use weighted selection within a zone
28    pub use_weights: bool,
29    /// Zone priority order for fallback (closest first)
30    /// If empty, all non-local zones are treated equally
31    pub zone_priority: Vec<String>,
32}
33
34impl Default for LocalityAwareConfig {
35    fn default() -> Self {
36        Self {
37            local_zone: std::env::var("SENTINEL_ZONE")
38                .or_else(|_| std::env::var("ZONE"))
39                .or_else(|_| std::env::var("REGION"))
40                .unwrap_or_else(|_| "default".to_string()),
41            fallback_strategy: LocalityFallback::RoundRobin,
42            min_local_healthy: 1,
43            use_weights: true,
44            zone_priority: Vec::new(),
45        }
46    }
47}
48
/// How the balancer behaves when the local zone cannot serve on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalityFallback {
    /// Rotate through the fallback targets in order.
    RoundRobin,
    /// Pick a fallback target uniformly at random.
    Random,
    /// Never leave the local zone; fail when no local target is available.
    FailLocal,
}
59
60/// Target with zone information
61#[derive(Debug, Clone)]
62struct ZonedTarget {
63    target: UpstreamTarget,
64    zone: String,
65}
66
67/// Locality-aware load balancer
68pub struct LocalityAwareBalancer {
69    /// All targets with zone information
70    targets: Vec<ZonedTarget>,
71    /// Health status per target address
72    health_status: Arc<RwLock<HashMap<String, bool>>>,
73    /// Round-robin counter for local zone
74    local_counter: AtomicUsize,
75    /// Round-robin counter for fallback
76    fallback_counter: AtomicUsize,
77    /// Configuration
78    config: LocalityAwareConfig,
79}
80
81impl LocalityAwareBalancer {
82    /// Create a new locality-aware balancer
83    ///
84    /// Zone information is extracted from target addresses using the format:
85    /// - `zone:host:port` - explicit zone prefix
86    /// - Or via target metadata (weight field encodes zone in high bits)
87    /// - Or defaults to "unknown" zone
88    pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
89        let mut health_status = HashMap::new();
90        let mut zoned_targets = Vec::with_capacity(targets.len());
91
92        for target in targets {
93            health_status.insert(target.full_address(), true);
94
95            // Extract zone from address if it contains zone prefix
96            // Format: "zone:host:port" or just "host:port"
97            let (zone, actual_target) = Self::parse_zone_from_target(&target);
98
99            zoned_targets.push(ZonedTarget {
100                target: actual_target,
101                zone,
102            });
103        }
104
105        debug!(
106            local_zone = %config.local_zone,
107            total_targets = zoned_targets.len(),
108            local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
109            "Created locality-aware balancer"
110        );
111
112        Self {
113            targets: zoned_targets,
114            health_status: Arc::new(RwLock::new(health_status)),
115            local_counter: AtomicUsize::new(0),
116            fallback_counter: AtomicUsize::new(0),
117            config,
118        }
119    }
120
121    /// Parse zone from target address
122    ///
123    /// Supports formats:
124    /// - `zone=us-west-1,host:port` - zone in metadata prefix
125    /// - `us-west-1/host:port` - zone as path prefix
126    /// - `host:port` - no zone, defaults to "unknown"
127    fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
128        let addr = &target.address;
129
130        // Check for zone= prefix (e.g., "zone=us-west-1,10.0.0.1")
131        if let Some(rest) = addr.strip_prefix("zone=") {
132            if let Some((zone, host)) = rest.split_once(',') {
133                return (
134                    zone.to_string(),
135                    UpstreamTarget::new(host, target.port, target.weight),
136                );
137            }
138        }
139
140        // Check for zone/ prefix (e.g., "us-west-1/10.0.0.1")
141        if let Some((zone, host)) = addr.split_once('/') {
142            // Ensure it's not an IP with port
143            if !zone.contains(':') && !zone.contains('.') {
144                return (
145                    zone.to_string(),
146                    UpstreamTarget::new(host, target.port, target.weight),
147                );
148            }
149        }
150
151        // No zone prefix, return as-is with unknown zone
152        ("unknown".to_string(), target.clone())
153    }
154
155    /// Get healthy targets in a specific zone
156    async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
157        let health = self.health_status.read().await;
158        self.targets
159            .iter()
160            .filter(|t| {
161                t.zone == zone && *health.get(&t.target.full_address()).unwrap_or(&true)
162            })
163            .collect()
164    }
165
166    /// Get all healthy targets not in the local zone, sorted by priority
167    async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
168        let health = self.health_status.read().await;
169        let local_zone = &self.config.local_zone;
170
171        let mut fallback: Vec<_> = self
172            .targets
173            .iter()
174            .filter(|t| {
175                t.zone != *local_zone && *health.get(&t.target.full_address()).unwrap_or(&true)
176            })
177            .collect();
178
179        // Sort by zone priority if specified
180        if !self.config.zone_priority.is_empty() {
181            fallback.sort_by(|a, b| {
182                let priority_a = self
183                    .config
184                    .zone_priority
185                    .iter()
186                    .position(|z| z == &a.zone)
187                    .unwrap_or(usize::MAX);
188                let priority_b = self
189                    .config
190                    .zone_priority
191                    .iter()
192                    .position(|z| z == &b.zone)
193                    .unwrap_or(usize::MAX);
194                priority_a.cmp(&priority_b)
195            });
196        }
197
198        fallback
199    }
200
201    /// Select from targets using round-robin
202    fn select_round_robin<'a>(
203        &self,
204        targets: &[&'a ZonedTarget],
205        counter: &AtomicUsize,
206    ) -> Option<&'a ZonedTarget> {
207        if targets.is_empty() {
208            return None;
209        }
210
211        if self.config.use_weights {
212            // Weighted round-robin
213            let total_weight: u32 = targets.iter().map(|t| t.target.weight).sum();
214            if total_weight == 0 {
215                return targets.first().copied();
216            }
217
218            let idx = counter.fetch_add(1, Ordering::Relaxed);
219            let mut weight_idx = (idx as u32) % total_weight;
220
221            for target in targets {
222                if weight_idx < target.target.weight {
223                    return Some(target);
224                }
225                weight_idx -= target.target.weight;
226            }
227
228            targets.first().copied()
229        } else {
230            let idx = counter.fetch_add(1, Ordering::Relaxed) % targets.len();
231            Some(targets[idx])
232        }
233    }
234
235    /// Select from targets using random selection
236    fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
237        use rand::seq::SliceRandom;
238
239        if targets.is_empty() {
240            return None;
241        }
242
243        let mut rng = rand::thread_rng();
244        targets.choose(&mut rng).copied()
245    }
246}
247
248#[async_trait]
249impl LoadBalancer for LocalityAwareBalancer {
250    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
251        trace!(
252            total_targets = self.targets.len(),
253            local_zone = %self.config.local_zone,
254            algorithm = "locality_aware",
255            "Selecting upstream target"
256        );
257
258        // First, try local zone
259        let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;
260
261        if local_healthy.len() >= self.config.min_local_healthy {
262            // Use local targets
263            let selected = self
264                .select_round_robin(&local_healthy, &self.local_counter)
265                .ok_or(SentinelError::NoHealthyUpstream)?;
266
267            trace!(
268                selected_target = %selected.target.full_address(),
269                zone = %selected.zone,
270                local_healthy = local_healthy.len(),
271                algorithm = "locality_aware",
272                "Selected local target"
273            );
274
275            return Ok(TargetSelection {
276                address: selected.target.full_address(),
277                weight: selected.target.weight,
278                metadata: {
279                    let mut m = HashMap::new();
280                    m.insert("zone".to_string(), selected.zone.clone());
281                    m.insert("locality".to_string(), "local".to_string());
282                    m
283                },
284            });
285        }
286
287        // Not enough local targets, check fallback strategy
288        match self.config.fallback_strategy {
289            LocalityFallback::FailLocal => {
290                warn!(
291                    local_zone = %self.config.local_zone,
292                    local_healthy = local_healthy.len(),
293                    min_required = self.config.min_local_healthy,
294                    algorithm = "locality_aware",
295                    "No healthy local targets and fallback disabled"
296                );
297                return Err(SentinelError::NoHealthyUpstream);
298            }
299            LocalityFallback::RoundRobin | LocalityFallback::Random => {
300                // Fall back to remote zones
301            }
302        }
303
304        // Get fallback targets (sorted by zone priority)
305        let fallback_targets = self.healthy_fallback().await;
306
307        // If we have some local targets, combine them with fallback
308        let all_targets: Vec<&ZonedTarget> = if !local_healthy.is_empty() {
309            // Local first, then fallback
310            local_healthy
311                .into_iter()
312                .chain(fallback_targets.into_iter())
313                .collect()
314        } else {
315            fallback_targets
316        };
317
318        if all_targets.is_empty() {
319            warn!(
320                total_targets = self.targets.len(),
321                algorithm = "locality_aware",
322                "No healthy upstream targets available"
323            );
324            return Err(SentinelError::NoHealthyUpstream);
325        }
326
327        // Select based on fallback strategy
328        let selected = match self.config.fallback_strategy {
329            LocalityFallback::RoundRobin => {
330                self.select_round_robin(&all_targets, &self.fallback_counter)
331            }
332            LocalityFallback::Random => self.select_random(&all_targets),
333            LocalityFallback::FailLocal => unreachable!(),
334        }
335        .ok_or(SentinelError::NoHealthyUpstream)?;
336
337        let is_local = selected.zone == self.config.local_zone;
338        debug!(
339            selected_target = %selected.target.full_address(),
340            zone = %selected.zone,
341            is_local = is_local,
342            fallback_used = !is_local,
343            algorithm = "locality_aware",
344            "Selected target (fallback path)"
345        );
346
347        Ok(TargetSelection {
348            address: selected.target.full_address(),
349            weight: selected.target.weight,
350            metadata: {
351                let mut m = HashMap::new();
352                m.insert("zone".to_string(), selected.zone.clone());
353                m.insert(
354                    "locality".to_string(),
355                    if is_local { "local" } else { "remote" }.to_string(),
356                );
357                m
358            },
359        })
360    }
361
362    async fn report_health(&self, address: &str, healthy: bool) {
363        trace!(
364            target = %address,
365            healthy = healthy,
366            algorithm = "locality_aware",
367            "Updating target health status"
368        );
369        self.health_status
370            .write()
371            .await
372            .insert(address.to_string(), healthy);
373    }
374
375    async fn healthy_targets(&self) -> Vec<String> {
376        self.health_status
377            .read()
378            .await
379            .iter()
380            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
381            .collect()
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Five targets on port 8080 with weight 100: two in us-west-1 (the local
    /// zone in these tests), two in us-east-1 and one in eu-west-1.
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        [
            // Local zone (us-west-1)
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            // Remote zone (us-east-1)
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            // Another remote zone (eu-west-1)
            "zone=eu-west-1,10.2.0.1",
        ]
        .iter()
        .map(|&address| UpstreamTarget::new(address, 8080, 100))
        .collect()
    }

    /// Configuration whose local zone is us-west-1; every other field is default.
    fn us_west_config() -> LocalityAwareConfig {
        LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            ..Default::default()
        }
    }

    /// Report both us-west-1 targets as unhealthy.
    async fn take_down_local_zone(balancer: &LocalityAwareBalancer) {
        for address in &["10.0.0.1:8080", "10.0.0.2:8080"] {
            balancer.report_health(address, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        // (raw address, expected zone); each case must strip down to host 10.0.0.1.
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"), // `zone=` prefix
            ("us-east-1/10.0.0.1", "us-east-1"),      // `zone/` prefix
            ("10.0.0.1", "unknown"),                  // no zone
        ];

        for &(raw, expected_zone) in &cases {
            let target = UpstreamTarget::new(raw, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&target);
            assert_eq!(zone, expected_zone);
            assert_eq!(parsed.address, "10.0.0.1");
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), us_west_config());

        // Every selection must stay in the local zone.
        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert!(
                selection.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                selection.address
            );
            assert_eq!(
                selection.metadata.get("locality").map(String::as_str),
                Some("local")
            );
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let config = LocalityAwareConfig {
            min_local_healthy: 1,
            ..us_west_config()
        };
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), config);
        take_down_local_zone(&balancer).await;

        // With the local zone down, selection must leave it.
        let selection = balancer.select(None).await.unwrap();
        assert!(
            !selection.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            selection.address
        );
        assert_eq!(
            selection.metadata.get("locality").map(String::as_str),
            Some("remote")
        );
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let config = LocalityAwareConfig {
            min_local_healthy: 1,
            zone_priority: vec!["us-east-1".to_string(), "eu-west-1".to_string()],
            ..us_west_config()
        };
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), config);
        take_down_local_zone(&balancer).await;

        // us-east-1 is listed first, so it wins over eu-west-1.
        let selection = balancer.select(None).await.unwrap();
        assert!(
            selection.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            selection.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let config = LocalityAwareConfig {
            fallback_strategy: LocalityFallback::FailLocal,
            ..us_west_config()
        };
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), config);
        take_down_local_zone(&balancer).await;

        // Fallback is disabled and no local target is healthy: selection fails.
        assert!(balancer.select(None).await.is_err());
    }
}
510}