Skip to main content

sentinel_proxy/upstream/
locality.rs

1//! Locality-aware load balancer
2//!
3//! Prefers targets in the same zone/region as the proxy, falling back to
4//! other zones when local targets are unhealthy or overloaded. Useful for
5//! multi-region deployments to minimize latency and cross-zone traffic costs.
6
7use async_trait::async_trait;
8use rand::seq::IndexedRandom;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, trace, warn};
14
15use sentinel_common::errors::{SentinelError, SentinelResult};
16
17use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
18
/// Configuration for locality-aware load balancing
///
/// The `Default` implementation resolves `local_zone` from the environment
/// (`SENTINEL_ZONE`, then `ZONE`, then `REGION`) and otherwise uses `"default"`.
#[derive(Debug, Clone)]
pub struct LocalityAwareConfig {
    /// The local zone/region identifier for this proxy instance.
    /// Targets whose parsed zone equals this string are treated as local.
    pub local_zone: String,
    /// How a target is chosen once the healthy local targets are too few to
    /// serve traffic on their own
    pub fallback_strategy: LocalityFallback,
    /// Number of healthy local targets required for selection to stay purely
    /// local; below this threshold the fallback strategy applies
    pub min_local_healthy: usize,
    /// Whether round-robin selection honours target weights.
    /// Random selection does not consult weights.
    pub use_weights: bool,
    /// Zone priority order for fallback (closest first), used to sort the
    /// non-local candidates; zones that are not listed sort last.
    /// If empty, all non-local zones are treated equally
    pub zone_priority: Vec<String>,
}
34
35impl Default for LocalityAwareConfig {
36    fn default() -> Self {
37        Self {
38            local_zone: std::env::var("SENTINEL_ZONE")
39                .or_else(|_| std::env::var("ZONE"))
40                .or_else(|_| std::env::var("REGION"))
41                .unwrap_or_else(|_| "default".to_string()),
42            fallback_strategy: LocalityFallback::RoundRobin,
43            min_local_healthy: 1,
44            use_weights: true,
45            zone_priority: Vec::new(),
46        }
47    }
48}
49
/// Fallback strategy when local targets are unavailable
///
/// Consulted only when the number of healthy local targets is below
/// `LocalityAwareConfig::min_local_healthy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalityFallback {
    /// Round-robin across fallback targets
    RoundRobin,
    /// Random selection from fallback targets
    Random,
    /// Fail immediately if no local targets: selection returns an error
    /// instead of using targets from other zones
    FailLocal,
}
60
/// Target with zone information
#[derive(Debug, Clone)]
struct ZonedTarget {
    /// Upstream target with any zone prefix removed from its address
    target: UpstreamTarget,
    /// Zone parsed from the configured address, or "unknown" when the address
    /// carried no zone prefix
    zone: String,
}
67
/// Locality-aware load balancer
///
/// Holds the zone-tagged target list, a shared health map and the
/// round-robin cursors used by target selection.
pub struct LocalityAwareBalancer {
    /// All targets with zone information
    targets: Vec<ZonedTarget>,
    /// Health status per target address; an address that is absent from the
    /// map is treated as healthy by the selection filters
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Round-robin counter for local zone
    local_counter: AtomicUsize,
    /// Round-robin counter for fallback
    fallback_counter: AtomicUsize,
    /// Configuration
    config: LocalityAwareConfig,
}
81
82impl LocalityAwareBalancer {
83    /// Create a new locality-aware balancer
84    ///
85    /// Zone information is extracted from target addresses using the format:
86    /// - `zone:host:port` - explicit zone prefix
87    /// - Or via target metadata (weight field encodes zone in high bits)
88    /// - Or defaults to "unknown" zone
89    pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
90        let mut health_status = HashMap::new();
91        let mut zoned_targets = Vec::with_capacity(targets.len());
92
93        for target in targets {
94            health_status.insert(target.full_address(), true);
95
96            // Extract zone from address if it contains zone prefix
97            // Format: "zone:host:port" or just "host:port"
98            let (zone, actual_target) = Self::parse_zone_from_target(&target);
99
100            zoned_targets.push(ZonedTarget {
101                target: actual_target,
102                zone,
103            });
104        }
105
106        debug!(
107            local_zone = %config.local_zone,
108            total_targets = zoned_targets.len(),
109            local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
110            "Created locality-aware balancer"
111        );
112
113        Self {
114            targets: zoned_targets,
115            health_status: Arc::new(RwLock::new(health_status)),
116            local_counter: AtomicUsize::new(0),
117            fallback_counter: AtomicUsize::new(0),
118            config,
119        }
120    }
121
122    /// Parse zone from target address
123    ///
124    /// Supports formats:
125    /// - `zone=us-west-1,host:port` - zone in metadata prefix
126    /// - `us-west-1/host:port` - zone as path prefix
127    /// - `host:port` - no zone, defaults to "unknown"
128    fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
129        let addr = &target.address;
130
131        // Check for zone= prefix (e.g., "zone=us-west-1,10.0.0.1")
132        if let Some(rest) = addr.strip_prefix("zone=") {
133            if let Some((zone, host)) = rest.split_once(',') {
134                return (
135                    zone.to_string(),
136                    UpstreamTarget::new(host, target.port, target.weight),
137                );
138            }
139        }
140
141        // Check for zone/ prefix (e.g., "us-west-1/10.0.0.1")
142        if let Some((zone, host)) = addr.split_once('/') {
143            // Ensure it's not an IP with port
144            if !zone.contains(':') && !zone.contains('.') {
145                return (
146                    zone.to_string(),
147                    UpstreamTarget::new(host, target.port, target.weight),
148                );
149            }
150        }
151
152        // No zone prefix, return as-is with unknown zone
153        ("unknown".to_string(), target.clone())
154    }
155
156    /// Get healthy targets in a specific zone
157    async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
158        let health = self.health_status.read().await;
159        self.targets
160            .iter()
161            .filter(|t| {
162                t.zone == zone && *health.get(&t.target.full_address()).unwrap_or(&true)
163            })
164            .collect()
165    }
166
167    /// Get all healthy targets not in the local zone, sorted by priority
168    async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
169        let health = self.health_status.read().await;
170        let local_zone = &self.config.local_zone;
171
172        let mut fallback: Vec<_> = self
173            .targets
174            .iter()
175            .filter(|t| {
176                t.zone != *local_zone && *health.get(&t.target.full_address()).unwrap_or(&true)
177            })
178            .collect();
179
180        // Sort by zone priority if specified
181        if !self.config.zone_priority.is_empty() {
182            fallback.sort_by(|a, b| {
183                let priority_a = self
184                    .config
185                    .zone_priority
186                    .iter()
187                    .position(|z| z == &a.zone)
188                    .unwrap_or(usize::MAX);
189                let priority_b = self
190                    .config
191                    .zone_priority
192                    .iter()
193                    .position(|z| z == &b.zone)
194                    .unwrap_or(usize::MAX);
195                priority_a.cmp(&priority_b)
196            });
197        }
198
199        fallback
200    }
201
202    /// Select from targets using round-robin
203    fn select_round_robin<'a>(
204        &self,
205        targets: &[&'a ZonedTarget],
206        counter: &AtomicUsize,
207    ) -> Option<&'a ZonedTarget> {
208        if targets.is_empty() {
209            return None;
210        }
211
212        if self.config.use_weights {
213            // Weighted round-robin
214            let total_weight: u32 = targets.iter().map(|t| t.target.weight).sum();
215            if total_weight == 0 {
216                return targets.first().copied();
217            }
218
219            let idx = counter.fetch_add(1, Ordering::Relaxed);
220            let mut weight_idx = (idx as u32) % total_weight;
221
222            for target in targets {
223                if weight_idx < target.target.weight {
224                    return Some(target);
225                }
226                weight_idx -= target.target.weight;
227            }
228
229            targets.first().copied()
230        } else {
231            let idx = counter.fetch_add(1, Ordering::Relaxed) % targets.len();
232            Some(targets[idx])
233        }
234    }
235
236    /// Select from targets using random selection
237    fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
238        use rand::seq::SliceRandom;
239
240        if targets.is_empty() {
241            return None;
242        }
243
244        let mut rng = rand::rng();
245        targets.choose(&mut rng).copied()
246    }
247}
248
249#[async_trait]
250impl LoadBalancer for LocalityAwareBalancer {
251    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
252        trace!(
253            total_targets = self.targets.len(),
254            local_zone = %self.config.local_zone,
255            algorithm = "locality_aware",
256            "Selecting upstream target"
257        );
258
259        // First, try local zone
260        let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;
261
262        if local_healthy.len() >= self.config.min_local_healthy {
263            // Use local targets
264            let selected = self
265                .select_round_robin(&local_healthy, &self.local_counter)
266                .ok_or(SentinelError::NoHealthyUpstream)?;
267
268            trace!(
269                selected_target = %selected.target.full_address(),
270                zone = %selected.zone,
271                local_healthy = local_healthy.len(),
272                algorithm = "locality_aware",
273                "Selected local target"
274            );
275
276            return Ok(TargetSelection {
277                address: selected.target.full_address(),
278                weight: selected.target.weight,
279                metadata: {
280                    let mut m = HashMap::new();
281                    m.insert("zone".to_string(), selected.zone.clone());
282                    m.insert("locality".to_string(), "local".to_string());
283                    m
284                },
285            });
286        }
287
288        // Not enough local targets, check fallback strategy
289        match self.config.fallback_strategy {
290            LocalityFallback::FailLocal => {
291                warn!(
292                    local_zone = %self.config.local_zone,
293                    local_healthy = local_healthy.len(),
294                    min_required = self.config.min_local_healthy,
295                    algorithm = "locality_aware",
296                    "No healthy local targets and fallback disabled"
297                );
298                return Err(SentinelError::NoHealthyUpstream);
299            }
300            LocalityFallback::RoundRobin | LocalityFallback::Random => {
301                // Fall back to remote zones
302            }
303        }
304
305        // Get fallback targets (sorted by zone priority)
306        let fallback_targets = self.healthy_fallback().await;
307
308        // If we have some local targets, combine them with fallback
309        let all_targets: Vec<&ZonedTarget> = if !local_healthy.is_empty() {
310            // Local first, then fallback
311            local_healthy
312                .into_iter()
313                .chain(fallback_targets.into_iter())
314                .collect()
315        } else {
316            fallback_targets
317        };
318
319        if all_targets.is_empty() {
320            warn!(
321                total_targets = self.targets.len(),
322                algorithm = "locality_aware",
323                "No healthy upstream targets available"
324            );
325            return Err(SentinelError::NoHealthyUpstream);
326        }
327
328        // Select based on fallback strategy
329        let selected = match self.config.fallback_strategy {
330            LocalityFallback::RoundRobin => {
331                self.select_round_robin(&all_targets, &self.fallback_counter)
332            }
333            LocalityFallback::Random => self.select_random(&all_targets),
334            LocalityFallback::FailLocal => unreachable!(),
335        }
336        .ok_or(SentinelError::NoHealthyUpstream)?;
337
338        let is_local = selected.zone == self.config.local_zone;
339        debug!(
340            selected_target = %selected.target.full_address(),
341            zone = %selected.zone,
342            is_local = is_local,
343            fallback_used = !is_local,
344            algorithm = "locality_aware",
345            "Selected target (fallback path)"
346        );
347
348        Ok(TargetSelection {
349            address: selected.target.full_address(),
350            weight: selected.target.weight,
351            metadata: {
352                let mut m = HashMap::new();
353                m.insert("zone".to_string(), selected.zone.clone());
354                m.insert(
355                    "locality".to_string(),
356                    if is_local { "local" } else { "remote" }.to_string(),
357                );
358                m
359            },
360        })
361    }
362
363    async fn report_health(&self, address: &str, healthy: bool) {
364        trace!(
365            target = %address,
366            healthy = healthy,
367            algorithm = "locality_aware",
368            "Updating target health status"
369        );
370        self.health_status
371            .write()
372            .await
373            .insert(address.to_string(), healthy);
374    }
375
376    async fn healthy_targets(&self) -> Vec<String> {
377        self.health_status
378            .read()
379            .await
380            .iter()
381            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
382            .collect()
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Zone the balancer under test is considered to run in.
    const LOCAL_ZONE: &str = "us-west-1";

    /// Addresses (after zone stripping) of the local-zone fixtures.
    const LOCAL_ADDRESSES: [&str; 2] = ["10.0.0.1:8080", "10.0.0.2:8080"];

    /// Five targets: two local (us-west-1), two in us-east-1, one in eu-west-1.
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        [
            // Local zone (us-west-1)
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            // Remote zone (us-east-1)
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            // Another remote zone (eu-west-1)
            "zone=eu-west-1,10.2.0.1",
        ]
        .into_iter()
        .map(|address| UpstreamTarget::new(address, 8080, 100))
        .collect()
    }

    /// Balancer over the standard fixtures with the given configuration.
    fn balancer_with(config: LocalityAwareConfig) -> LocalityAwareBalancer {
        LocalityAwareBalancer::new(make_zoned_targets(), config)
    }

    /// Report every local-zone fixture as unhealthy.
    async fn mark_local_unhealthy(balancer: &LocalityAwareBalancer) {
        for address in LOCAL_ADDRESSES {
            balancer.report_health(address, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        // (configured address, expected zone): zone= prefix, zone/ prefix, no zone
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"),
            ("us-east-1/10.0.0.1", "us-east-1"),
            ("10.0.0.1", "unknown"),
        ];

        for (raw_address, expected_zone) in cases {
            let target = UpstreamTarget::new(raw_address, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&target);
            assert_eq!(zone, expected_zone);
            assert_eq!(parsed.address, "10.0.0.1");
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            ..Default::default()
        });

        // All selections should be from local zone
        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert!(
                selection.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                selection.address
            );
            assert_eq!(selection.metadata.get("locality").unwrap(), "local");
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            min_local_healthy: 1,
            ..Default::default()
        });

        mark_local_unhealthy(&balancer).await;

        // Should now use fallback targets
        let selection = balancer.select(None).await.unwrap();
        assert!(
            !selection.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            selection.address
        );
        assert_eq!(selection.metadata.get("locality").unwrap(), "remote");
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            min_local_healthy: 1,
            zone_priority: vec!["us-east-1".to_string(), "eu-west-1".to_string()],
            ..Default::default()
        });

        mark_local_unhealthy(&balancer).await;

        // Should prefer us-east-1 over eu-west-1
        let selection = balancer.select(None).await.unwrap();
        assert!(
            selection.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            selection.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            fallback_strategy: LocalityFallback::FailLocal,
            ..Default::default()
        });

        mark_local_unhealthy(&balancer).await;

        // Should fail
        let result = balancer.select(None).await;
        assert!(result.is_err());
    }
}