1use async_trait::async_trait;
8use rand::seq::IndexedRandom;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, trace, warn};
14
15use sentinel_common::errors::{SentinelError, SentinelResult};
16
17use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
18
/// Configuration for [`LocalityAwareBalancer`].
#[derive(Debug, Clone)]
pub struct LocalityAwareConfig {
    /// Zone this instance runs in; healthy targets in this zone are preferred.
    pub local_zone: String,
    /// What to do when the local zone has fewer than `min_local_healthy`
    /// healthy targets.
    pub fallback_strategy: LocalityFallback,
    /// Minimum number of healthy local targets required for a selection to be
    /// served exclusively from the local zone.
    pub min_local_healthy: usize,
    /// When true, round-robin selection distributes picks in proportion to
    /// target weights; random fallback selection does not look at weights.
    pub use_weights: bool,
    /// Ordering applied to non-local fallback candidates: zones earlier in this
    /// list are placed first, zones not listed are placed last.
    pub zone_priority: Vec<String>,
}
34
35impl Default for LocalityAwareConfig {
36 fn default() -> Self {
37 Self {
38 local_zone: std::env::var("SENTINEL_ZONE")
39 .or_else(|_| std::env::var("ZONE"))
40 .or_else(|_| std::env::var("REGION"))
41 .unwrap_or_else(|_| "default".to_string()),
42 fallback_strategy: LocalityFallback::RoundRobin,
43 min_local_healthy: 1,
44 use_weights: true,
45 zone_priority: Vec::new(),
46 }
47 }
48}
49
/// Behaviour when the local zone cannot satisfy
/// [`LocalityAwareConfig::min_local_healthy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalityFallback {
    /// Rotate across the remaining healthy local targets followed by healthy
    /// targets in other zones.
    RoundRobin,
    /// Pick at random among the remaining healthy local targets and healthy
    /// targets in other zones.
    Random,
    /// Never leave the local zone; selection fails with
    /// `SentinelError::NoHealthyUpstream` instead.
    FailLocal,
}
60
/// An upstream target paired with the zone parsed from its configured address.
#[derive(Debug, Clone)]
struct ZonedTarget {
    /// The target with any zone annotation removed from its address.
    target: UpstreamTarget,
    /// Zone name, or `"unknown"` when the address carried no zone annotation.
    zone: String,
}
67
/// Load balancer that prefers healthy targets in the local zone and falls back
/// to other zones according to [`LocalityFallback`].
pub struct LocalityAwareBalancer {
    /// All targets with their parsed zones, in configuration order.
    targets: Vec<ZonedTarget>,
    /// Health flags keyed by address string. Selection looks targets up by the
    /// zone-stripped target's `full_address()`; addresses missing from the map
    /// are treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Rotation counter for selections served from the local zone.
    local_counter: AtomicUsize,
    /// Rotation counter for selections made on the fallback path.
    fallback_counter: AtomicUsize,
    /// Balancer configuration.
    config: LocalityAwareConfig,
}
81
82impl LocalityAwareBalancer {
83 pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
90 let mut health_status = HashMap::new();
91 let mut zoned_targets = Vec::with_capacity(targets.len());
92
93 for target in targets {
94 health_status.insert(target.full_address(), true);
95
96 let (zone, actual_target) = Self::parse_zone_from_target(&target);
99
100 zoned_targets.push(ZonedTarget {
101 target: actual_target,
102 zone,
103 });
104 }
105
106 debug!(
107 local_zone = %config.local_zone,
108 total_targets = zoned_targets.len(),
109 local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
110 "Created locality-aware balancer"
111 );
112
113 Self {
114 targets: zoned_targets,
115 health_status: Arc::new(RwLock::new(health_status)),
116 local_counter: AtomicUsize::new(0),
117 fallback_counter: AtomicUsize::new(0),
118 config,
119 }
120 }
121
122 fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
129 let addr = &target.address;
130
131 if let Some(rest) = addr.strip_prefix("zone=") {
133 if let Some((zone, host)) = rest.split_once(',') {
134 return (
135 zone.to_string(),
136 UpstreamTarget::new(host, target.port, target.weight),
137 );
138 }
139 }
140
141 if let Some((zone, host)) = addr.split_once('/') {
143 if !zone.contains(':') && !zone.contains('.') {
145 return (
146 zone.to_string(),
147 UpstreamTarget::new(host, target.port, target.weight),
148 );
149 }
150 }
151
152 ("unknown".to_string(), target.clone())
154 }
155
156 async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
158 let health = self.health_status.read().await;
159 self.targets
160 .iter()
161 .filter(|t| {
162 t.zone == zone && *health.get(&t.target.full_address()).unwrap_or(&true)
163 })
164 .collect()
165 }
166
167 async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
169 let health = self.health_status.read().await;
170 let local_zone = &self.config.local_zone;
171
172 let mut fallback: Vec<_> = self
173 .targets
174 .iter()
175 .filter(|t| {
176 t.zone != *local_zone && *health.get(&t.target.full_address()).unwrap_or(&true)
177 })
178 .collect();
179
180 if !self.config.zone_priority.is_empty() {
182 fallback.sort_by(|a, b| {
183 let priority_a = self
184 .config
185 .zone_priority
186 .iter()
187 .position(|z| z == &a.zone)
188 .unwrap_or(usize::MAX);
189 let priority_b = self
190 .config
191 .zone_priority
192 .iter()
193 .position(|z| z == &b.zone)
194 .unwrap_or(usize::MAX);
195 priority_a.cmp(&priority_b)
196 });
197 }
198
199 fallback
200 }
201
202 fn select_round_robin<'a>(
204 &self,
205 targets: &[&'a ZonedTarget],
206 counter: &AtomicUsize,
207 ) -> Option<&'a ZonedTarget> {
208 if targets.is_empty() {
209 return None;
210 }
211
212 if self.config.use_weights {
213 let total_weight: u32 = targets.iter().map(|t| t.target.weight).sum();
215 if total_weight == 0 {
216 return targets.first().copied();
217 }
218
219 let idx = counter.fetch_add(1, Ordering::Relaxed);
220 let mut weight_idx = (idx as u32) % total_weight;
221
222 for target in targets {
223 if weight_idx < target.target.weight {
224 return Some(target);
225 }
226 weight_idx -= target.target.weight;
227 }
228
229 targets.first().copied()
230 } else {
231 let idx = counter.fetch_add(1, Ordering::Relaxed) % targets.len();
232 Some(targets[idx])
233 }
234 }
235
236 fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
238 use rand::seq::SliceRandom;
239
240 if targets.is_empty() {
241 return None;
242 }
243
244 let mut rng = rand::rng();
245 targets.choose(&mut rng).copied()
246 }
247}
248
249#[async_trait]
250impl LoadBalancer for LocalityAwareBalancer {
251 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
252 trace!(
253 total_targets = self.targets.len(),
254 local_zone = %self.config.local_zone,
255 algorithm = "locality_aware",
256 "Selecting upstream target"
257 );
258
259 let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;
261
262 if local_healthy.len() >= self.config.min_local_healthy {
263 let selected = self
265 .select_round_robin(&local_healthy, &self.local_counter)
266 .ok_or(SentinelError::NoHealthyUpstream)?;
267
268 trace!(
269 selected_target = %selected.target.full_address(),
270 zone = %selected.zone,
271 local_healthy = local_healthy.len(),
272 algorithm = "locality_aware",
273 "Selected local target"
274 );
275
276 return Ok(TargetSelection {
277 address: selected.target.full_address(),
278 weight: selected.target.weight,
279 metadata: {
280 let mut m = HashMap::new();
281 m.insert("zone".to_string(), selected.zone.clone());
282 m.insert("locality".to_string(), "local".to_string());
283 m
284 },
285 });
286 }
287
288 match self.config.fallback_strategy {
290 LocalityFallback::FailLocal => {
291 warn!(
292 local_zone = %self.config.local_zone,
293 local_healthy = local_healthy.len(),
294 min_required = self.config.min_local_healthy,
295 algorithm = "locality_aware",
296 "No healthy local targets and fallback disabled"
297 );
298 return Err(SentinelError::NoHealthyUpstream);
299 }
300 LocalityFallback::RoundRobin | LocalityFallback::Random => {
301 }
303 }
304
305 let fallback_targets = self.healthy_fallback().await;
307
308 let all_targets: Vec<&ZonedTarget> = if !local_healthy.is_empty() {
310 local_healthy
312 .into_iter()
313 .chain(fallback_targets.into_iter())
314 .collect()
315 } else {
316 fallback_targets
317 };
318
319 if all_targets.is_empty() {
320 warn!(
321 total_targets = self.targets.len(),
322 algorithm = "locality_aware",
323 "No healthy upstream targets available"
324 );
325 return Err(SentinelError::NoHealthyUpstream);
326 }
327
328 let selected = match self.config.fallback_strategy {
330 LocalityFallback::RoundRobin => {
331 self.select_round_robin(&all_targets, &self.fallback_counter)
332 }
333 LocalityFallback::Random => self.select_random(&all_targets),
334 LocalityFallback::FailLocal => unreachable!(),
335 }
336 .ok_or(SentinelError::NoHealthyUpstream)?;
337
338 let is_local = selected.zone == self.config.local_zone;
339 debug!(
340 selected_target = %selected.target.full_address(),
341 zone = %selected.zone,
342 is_local = is_local,
343 fallback_used = !is_local,
344 algorithm = "locality_aware",
345 "Selected target (fallback path)"
346 );
347
348 Ok(TargetSelection {
349 address: selected.target.full_address(),
350 weight: selected.target.weight,
351 metadata: {
352 let mut m = HashMap::new();
353 m.insert("zone".to_string(), selected.zone.clone());
354 m.insert(
355 "locality".to_string(),
356 if is_local { "local" } else { "remote" }.to_string(),
357 );
358 m
359 },
360 })
361 }
362
363 async fn report_health(&self, address: &str, healthy: bool) {
364 trace!(
365 target = %address,
366 healthy = healthy,
367 algorithm = "locality_aware",
368 "Updating target health status"
369 );
370 self.health_status
371 .write()
372 .await
373 .insert(address.to_string(), healthy);
374 }
375
376 async fn healthy_targets(&self) -> Vec<String> {
377 self.health_status
378 .read()
379 .await
380 .iter()
381 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
382 .collect()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Zone treated as local by every balancer test below.
    const LOCAL_ZONE: &str = "us-west-1";

    /// Five weight-100 targets on port 8080: two in us-west-1, two in
    /// us-east-1, one in eu-west-1.
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        [
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            "zone=eu-west-1,10.2.0.1",
        ]
        .into_iter()
        .map(|addr| UpstreamTarget::new(addr, 8080, 100))
        .collect()
    }

    /// Builds a balancer over the standard target set with `config`.
    fn balancer_with(config: LocalityAwareConfig) -> LocalityAwareBalancer {
        LocalityAwareBalancer::new(make_zoned_targets(), config)
    }

    /// Marks both us-west-1 targets unhealthy.
    async fn fail_local_zone(balancer: &LocalityAwareBalancer) {
        for addr in ["10.0.0.1:8080", "10.0.0.2:8080"] {
            balancer.report_health(addr, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"),
            ("us-east-1/10.0.0.1", "us-east-1"),
            ("10.0.0.1", "unknown"),
        ];

        for (raw, expected_zone) in cases {
            let target = UpstreamTarget::new(raw, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&target);
            assert_eq!(zone, expected_zone, "zone parsed from {raw}");
            assert_eq!(parsed.address, "10.0.0.1", "address parsed from {raw}");
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            ..Default::default()
        });

        for _ in 0..10 {
            let picked = balancer.select(None).await.unwrap();
            assert!(
                picked.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                picked.address
            );
            assert_eq!(picked.metadata.get("locality").unwrap(), "local");
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            min_local_healthy: 1,
            ..Default::default()
        });
        fail_local_zone(&balancer).await;

        let picked = balancer.select(None).await.unwrap();
        assert!(
            !picked.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            picked.address
        );
        assert_eq!(picked.metadata.get("locality").unwrap(), "remote");
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            min_local_healthy: 1,
            zone_priority: vec!["us-east-1".to_string(), "eu-west-1".to_string()],
            ..Default::default()
        });
        fail_local_zone(&balancer).await;

        let picked = balancer.select(None).await.unwrap();
        assert!(
            picked.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            picked.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            fallback_strategy: LocalityFallback::FailLocal,
            ..Default::default()
        });
        fail_local_zone(&balancer).await;

        assert!(balancer.select(None).await.is_err());
    }
}