1use murmur3::murmur3_32;
2use std::collections::{BTreeMap, HashMap};
3use std::hash::{Hash, Hasher};
4use std::io::Cursor;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use xxhash_rust::xxh3::Xxh3;
9
10use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
11use sentinel_common::errors::{SentinelError, SentinelResult};
12use async_trait::async_trait;
13use tracing::{debug, info, trace, warn};
14
/// Hash algorithm used to place keys and virtual nodes on the ring.
#[derive(Debug, Clone, Copy)]
pub enum HashFunction {
    /// XXH3 (64-bit) — fast with high-quality distribution.
    Xxh3,
    /// Murmur3; produces 32 bits, widened to u64 (see `hash_key`).
    Murmur3,
    /// std's `DefaultHasher`; output is not guaranteed stable across
    /// Rust releases, so ring placement may shift after a toolchain bump.
    DefaultHasher,
}
22
/// Tuning knobs for the consistent-hash balancer.
#[derive(Debug, Clone)]
pub struct ConsistentHashConfig {
    /// Virtual nodes per physical target; more nodes smooth the key
    /// distribution at the cost of a larger ring.
    pub virtual_nodes: usize,
    /// Algorithm used for both ring placement and request-key hashing.
    pub hash_function: HashFunction,
    /// Enable "consistent hashing with bounded loads" (per-target load cap).
    pub bounded_loads: bool,
    /// Load cap as a multiple of the average load (e.g. 1.25 = 125%).
    pub max_load_factor: f64,
    /// How the per-request affinity key is derived.
    pub hash_key_extractor: HashKeyExtractor,
}
37
38impl Default for ConsistentHashConfig {
39 fn default() -> Self {
40 Self {
41 virtual_nodes: 150,
42 hash_function: HashFunction::Xxh3,
43 bounded_loads: true,
44 max_load_factor: 1.25,
45 hash_key_extractor: HashKeyExtractor::ClientIp,
46 }
47 }
48}
49
/// Strategy for deriving the affinity key from an incoming request.
#[derive(Clone)]
pub enum HashKeyExtractor {
    /// Hash on the client's IP address.
    ClientIp,
    /// Hash on the value of the named request header.
    Header(String),
    /// Hash on the value of the named cookie, parsed out of the `cookie` header.
    Cookie(String),
    /// Caller-supplied extractor; returning `None` makes `select` fall back
    /// to a random key (no session affinity for that request).
    Custom(Arc<dyn Fn(&RequestContext) -> Option<String> + Send + Sync>),
}
58
59impl std::fmt::Debug for HashKeyExtractor {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::ClientIp => write!(f, "ClientIp"),
63 Self::Header(h) => write!(f, "Header({})", h),
64 Self::Cookie(c) => write!(f, "Cookie({})", c),
65 Self::Custom(_) => write!(f, "Custom"),
66 }
67 }
68}
69
/// A single point on the hash ring. Each physical target contributes
/// `virtual_nodes` of these to smooth the key distribution.
#[derive(Debug, Clone)]
struct VirtualNode {
    /// Ring position; duplicates the BTreeMap key it is stored under.
    hash: u64,
    /// Index into `ConsistentHashBalancer::targets`.
    target_index: usize,
    /// Which of the target's virtual replicas this is (0..virtual_nodes).
    virtual_index: usize,
}
80
/// Consistent-hash load balancer with virtual nodes and optional
/// "bounded loads" (cap each target at `max_load_factor` x average load).
pub struct ConsistentHashBalancer {
    config: ConsistentHashConfig,
    /// All configured targets; ring entries store indices into this Vec.
    targets: Vec<UpstreamTarget>,
    /// hash -> virtual node, ordered so range queries walk the ring clockwise.
    ring: Arc<RwLock<BTreeMap<u64, VirtualNode>>>,
    /// "host:port" -> healthy flag; absent entries are treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Per-target in-flight connection counters (parallel to `targets`).
    connection_counts: Vec<Arc<AtomicU64>>,
    /// Sum of in-flight connections across all targets.
    total_connections: Arc<AtomicU64>,
    /// key-hash -> target-index memoization; cleared on every ring rebuild.
    lookup_cache: Arc<RwLock<HashMap<u64, usize>>>,
    /// Incremented on each ring rebuild (lets tests/observers detect rebuilds).
    generation: Arc<AtomicUsize>,
}
100
101impl ConsistentHashBalancer {
102 pub fn new(targets: Vec<UpstreamTarget>, config: ConsistentHashConfig) -> Self {
103 trace!(
104 target_count = targets.len(),
105 virtual_nodes = config.virtual_nodes,
106 hash_function = ?config.hash_function,
107 bounded_loads = config.bounded_loads,
108 max_load_factor = config.max_load_factor,
109 hash_key_extractor = ?config.hash_key_extractor,
110 "Creating consistent hash balancer"
111 );
112
113 let connection_counts = targets
114 .iter()
115 .map(|_| Arc::new(AtomicU64::new(0)))
116 .collect();
117
118 let balancer = Self {
119 config,
120 targets: targets.clone(),
121 ring: Arc::new(RwLock::new(BTreeMap::new())),
122 health_status: Arc::new(RwLock::new(HashMap::new())),
123 connection_counts,
124 total_connections: Arc::new(AtomicU64::new(0)),
125 lookup_cache: Arc::new(RwLock::new(HashMap::with_capacity(1000))),
126 generation: Arc::new(AtomicUsize::new(0)),
127 };
128
129 tokio::task::block_in_place(|| {
131 tokio::runtime::Handle::current().block_on(balancer.rebuild_ring());
132 });
133
134 debug!(
135 target_count = targets.len(),
136 "Consistent hash balancer initialized"
137 );
138
139 balancer
140 }
141
142 async fn rebuild_ring(&self) {
144 trace!(
145 total_targets = self.targets.len(),
146 virtual_nodes_per_target = self.config.virtual_nodes,
147 "Starting hash ring rebuild"
148 );
149
150 let mut new_ring = BTreeMap::new();
151 let health = self.health_status.read().await;
152
153 for (index, target) in self.targets.iter().enumerate() {
154 let target_id = format!("{}:{}", target.address, target.port);
155 let is_healthy = health.get(&target_id).copied().unwrap_or(true);
156
157 if !is_healthy {
158 trace!(
159 target_id = %target_id,
160 target_index = index,
161 "Skipping unhealthy target in ring rebuild"
162 );
163 continue;
164 }
165
166 for vnode in 0..self.config.virtual_nodes {
168 let vnode_key = format!("{}-vnode-{}", target_id, vnode);
169 let hash = self.hash_key(&vnode_key);
170
171 new_ring.insert(
172 hash,
173 VirtualNode {
174 hash,
175 target_index: index,
176 virtual_index: vnode,
177 },
178 );
179 }
180
181 trace!(
182 target_id = %target_id,
183 target_index = index,
184 vnodes_added = self.config.virtual_nodes,
185 "Added virtual nodes for target"
186 );
187 }
188
189 let healthy_count = new_ring
190 .values()
191 .map(|n| n.target_index)
192 .collect::<std::collections::HashSet<_>>()
193 .len();
194
195 if new_ring.is_empty() {
196 warn!("No healthy targets available for consistent hash ring");
197 } else {
198 info!(
199 virtual_nodes = new_ring.len(),
200 healthy_targets = healthy_count,
201 "Rebuilt consistent hash ring"
202 );
203 }
204
205 *self.ring.write().await = new_ring;
206
207 let cache_size = self.lookup_cache.read().await.len();
209 self.lookup_cache.write().await.clear();
210 let new_generation = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
211
212 trace!(
213 cache_entries_cleared = cache_size,
214 new_generation = new_generation,
215 "Ring rebuild complete, cache cleared"
216 );
217 }
218
219 fn hash_key(&self, key: &str) -> u64 {
221 match self.config.hash_function {
222 HashFunction::Xxh3 => {
223 let mut hasher = Xxh3::new();
224 hasher.update(key.as_bytes());
225 hasher.digest()
226 }
227 HashFunction::Murmur3 => {
228 let mut cursor = Cursor::new(key.as_bytes());
229 murmur3_32(&mut cursor, 0).unwrap_or(0) as u64
230 }
231 HashFunction::DefaultHasher => {
232 use std::collections::hash_map::DefaultHasher;
233 let mut hasher = DefaultHasher::new();
234 key.hash(&mut hasher);
235 hasher.finish()
236 }
237 }
238 }
239
240 async fn find_target(&self, hash_key: &str) -> Option<usize> {
242 let key_hash = self.hash_key(hash_key);
243
244 trace!(
245 hash_key = %hash_key,
246 key_hash = key_hash,
247 bounded_loads = self.config.bounded_loads,
248 "Finding target for hash key"
249 );
250
251 {
253 let cache = self.lookup_cache.read().await;
254 if let Some(&target_index) = cache.get(&key_hash) {
255 let health = self.health_status.read().await;
257 let target = &self.targets[target_index];
258 let target_id = format!("{}:{}", target.address, target.port);
259 if health.get(&target_id).copied().unwrap_or(true) {
260 trace!(
261 hash_key = %hash_key,
262 target_index = target_index,
263 "Cache hit for hash key"
264 );
265 return Some(target_index);
266 }
267 trace!(
268 hash_key = %hash_key,
269 target_index = target_index,
270 "Cache hit but target unhealthy"
271 );
272 }
273 }
274
275 let ring = self.ring.read().await;
276
277 if ring.is_empty() {
278 warn!("Hash ring is empty, no targets available");
279 return None;
280 }
281
282 let candidates = if let Some((&_node_hash, vnode)) = ring.range(key_hash..).next() {
284 vec![vnode.clone()]
285 } else {
286 ring.iter()
288 .next()
289 .map(|(_, vnode)| vec![vnode.clone()])
290 .unwrap_or_default()
291 };
292
293 trace!(
294 hash_key = %hash_key,
295 candidate_count = candidates.len(),
296 "Found candidates on hash ring"
297 );
298
299 if !self.config.bounded_loads {
301 let target_index = candidates.first().map(|n| n.target_index);
302
303 if let Some(idx) = target_index {
305 self.lookup_cache.write().await.insert(key_hash, idx);
306 trace!(
307 hash_key = %hash_key,
308 target_index = idx,
309 "Selected target (no bounded loads)"
310 );
311 }
312
313 return target_index;
314 }
315
316 let avg_load = self.calculate_average_load().await;
318 let max_load = (avg_load * self.config.max_load_factor) as u64;
319
320 trace!(
321 avg_load = avg_load,
322 max_load = max_load,
323 max_load_factor = self.config.max_load_factor,
324 "Checking bounded loads"
325 );
326
327 for vnode in candidates {
329 let current_load = self.connection_counts[vnode.target_index].load(Ordering::Relaxed);
330
331 trace!(
332 target_index = vnode.target_index,
333 current_load = current_load,
334 max_load = max_load,
335 "Evaluating candidate load"
336 );
337
338 if current_load <= max_load {
339 self.lookup_cache
341 .write()
342 .await
343 .insert(key_hash, vnode.target_index);
344 debug!(
345 hash_key = %hash_key,
346 target_index = vnode.target_index,
347 current_load = current_load,
348 "Selected target within load bounds"
349 );
350 return Some(vnode.target_index);
351 }
352 }
353
354 trace!(
355 hash_key = %hash_key,
356 "All candidates overloaded, falling back to least loaded"
357 );
358
359 self.find_least_loaded_target().await
361 }
362
363 async fn calculate_average_load(&self) -> f64 {
365 let health = self.health_status.read().await;
366 let healthy_count = self
367 .targets
368 .iter()
369 .filter(|t| {
370 let target_id = format!("{}:{}", t.address, t.port);
371 health.get(&target_id).copied().unwrap_or(true)
372 })
373 .count();
374
375 if healthy_count == 0 {
376 return 0.0;
377 }
378
379 let total = self.total_connections.load(Ordering::Relaxed);
380 total as f64 / healthy_count as f64
381 }
382
383 async fn find_least_loaded_target(&self) -> Option<usize> {
385 trace!("Finding least loaded target as fallback");
386
387 let health = self.health_status.read().await;
388
389 let mut min_load = u64::MAX;
390 let mut best_target = None;
391
392 for (index, target) in self.targets.iter().enumerate() {
393 let target_id = format!("{}:{}", target.address, target.port);
394 if !health.get(&target_id).copied().unwrap_or(true) {
395 trace!(
396 target_index = index,
397 target_id = %target_id,
398 "Skipping unhealthy target"
399 );
400 continue;
401 }
402
403 let load = self.connection_counts[index].load(Ordering::Relaxed);
404 trace!(
405 target_index = index,
406 target_id = %target_id,
407 load = load,
408 "Evaluating target load"
409 );
410
411 if load < min_load {
412 min_load = load;
413 best_target = Some(index);
414 }
415 }
416
417 if let Some(idx) = best_target {
418 debug!(
419 target_index = idx,
420 load = min_load,
421 "Selected least loaded target"
422 );
423 } else {
424 warn!("No healthy targets found for least loaded selection");
425 }
426
427 best_target
428 }
429
430 pub fn extract_hash_key(&self, context: &RequestContext) -> Option<String> {
432 let key = match &self.config.hash_key_extractor {
433 HashKeyExtractor::ClientIp => context.client_ip.map(|ip| ip.to_string()),
434 HashKeyExtractor::Header(name) => context.headers.get(name).cloned(),
435 HashKeyExtractor::Cookie(name) => {
436 context.headers.get("cookie").and_then(|cookies| {
438 cookies.split(';').find_map(|cookie| {
439 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
440 if parts.len() == 2 && parts[0] == name {
441 Some(parts[1].to_string())
442 } else {
443 None
444 }
445 })
446 })
447 }
448 HashKeyExtractor::Custom(extractor) => extractor(context),
449 };
450
451 trace!(
452 extractor = ?self.config.hash_key_extractor,
453 key_found = key.is_some(),
454 "Extracted hash key from request"
455 );
456
457 key
458 }
459
460 pub fn acquire_connection(&self, target_index: usize) {
462 let count = self.connection_counts[target_index].fetch_add(1, Ordering::Relaxed) + 1;
463 let total = self.total_connections.fetch_add(1, Ordering::Relaxed) + 1;
464 trace!(
465 target_index = target_index,
466 target_connections = count,
467 total_connections = total,
468 "Acquired connection"
469 );
470 }
471
472 pub fn release_connection(&self, target_index: usize) {
474 let count = self.connection_counts[target_index].fetch_sub(1, Ordering::Relaxed) - 1;
475 let total = self.total_connections.fetch_sub(1, Ordering::Relaxed) - 1;
476 trace!(
477 target_index = target_index,
478 target_connections = count,
479 total_connections = total,
480 "Released connection"
481 );
482 }
483}
484
485#[async_trait]
486impl LoadBalancer for ConsistentHashBalancer {
    /// Selects an upstream target for the request.
    ///
    /// The affinity key comes from `extract_hash_key`; when the context
    /// yields no key (or there is no context) a random key is generated,
    /// which degrades this request to effectively random placement (no
    /// session affinity). Under bounded loads the selected target's
    /// connection counter is incremented here and must be balanced by a
    /// later `release()` call.
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            has_context = context.is_some(),
            "Consistent hash select called"
        );

        // (key, used_random) — the flag is diagnostic only (logged below).
        let (hash_key, used_random) = context
            .and_then(|ctx| self.extract_hash_key(ctx))
            .map(|k| (k, false))
            .unwrap_or_else(|| {
                use rand::Rng;
                let mut rng = rand::thread_rng();
                let key = format!("random-{}", rng.gen::<u64>());
                trace!(random_key = %key, "Generated random hash key (no context key)");
                (key, true)
            });

        let target_index = self
            .find_target(&hash_key)
            .await
            .ok_or_else(|| {
                warn!("No healthy upstream targets available");
                SentinelError::NoHealthyUpstream
            })?;

        let target = &self.targets[target_index];

        // Count the connection only under bounded loads; release() applies
        // the same condition when decrementing, keeping the pair symmetric.
        if self.config.bounded_loads {
            self.acquire_connection(target_index);
        }

        // Snapshot taken after the acquire, so it includes this connection.
        let current_load = self.connection_counts[target_index].load(Ordering::Relaxed);

        debug!(
            target = %format!("{}:{}", target.address, target.port),
            hash_key = %hash_key,
            target_index = target_index,
            current_load = current_load,
            used_random_key = used_random,
            "Consistent hash selected target"
        );

        Ok(TargetSelection {
            address: format!("{}:{}", target.address, target.port),
            weight: target.weight,
            metadata: {
                // "target_index" is read back by release(); key must stay stable.
                let mut meta = HashMap::new();
                meta.insert("hash_key".to_string(), hash_key);
                meta.insert("target_index".to_string(), target_index.to_string());
                meta.insert("load".to_string(), current_load.to_string());
                meta.insert("algorithm".to_string(), "consistent_hash".to_string());
                meta
            },
        })
    }
545
546 async fn report_health(&self, address: &str, healthy: bool) {
547 trace!(
548 address = %address,
549 healthy = healthy,
550 "Reporting target health"
551 );
552
553 let mut health = self.health_status.write().await;
554 let previous = health.insert(address.to_string(), healthy);
555
556 if previous != Some(healthy) {
558 info!(
559 address = %address,
560 previous_status = ?previous,
561 new_status = healthy,
562 "Target health changed, rebuilding ring"
563 );
564 drop(health); self.rebuild_ring().await;
566 }
567 }
568
569 async fn healthy_targets(&self) -> Vec<String> {
570 let health = self.health_status.read().await;
571 let targets: Vec<String> = self.targets
572 .iter()
573 .filter_map(|t| {
574 let target_id = format!("{}:{}", t.address, t.port);
575 if health.get(&target_id).copied().unwrap_or(true) {
576 Some(target_id)
577 } else {
578 None
579 }
580 })
581 .collect();
582
583 trace!(
584 total_targets = self.targets.len(),
585 healthy_count = targets.len(),
586 "Retrieved healthy targets"
587 );
588
589 targets
590 }
591
592 async fn release(&self, selection: &TargetSelection) {
594 if self.config.bounded_loads {
595 if let Some(index_str) = selection.metadata.get("target_index") {
596 if let Ok(index) = index_str.parse::<usize>() {
597 trace!(
598 target_index = index,
599 address = %selection.address,
600 "Releasing connection for bounded loads"
601 );
602 self.release_connection(index);
603 }
604 }
605 }
606 }
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` equal-weight targets 10.0.0.1..=10.0.0.count on port 8080.
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    /// With bounded loads off, 10k requests (over at most 256 distinct
    /// client IPs) should spread across 5 targets within ±50% of an even
    /// share — consistent hashing is only statistically even.
    #[tokio::test]
    async fn test_consistent_distribution() {
        let targets = create_test_targets(5);
        let config = ConsistentHashConfig {
            virtual_nodes: 100,
            bounded_loads: false,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Tally selections per target index via the metadata echo.
        let mut distribution = vec![0u64; targets.len()];

        for i in 0..10000 {
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:1234", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: "/".to_string(),
                method: "GET".to_string(),
            };

            if let Ok(selection) = balancer.select(Some(&context)).await {
                if let Some(index_str) = selection.metadata.get("target_index") {
                    if let Ok(index) = index_str.parse::<usize>() {
                        distribution[index] += 1;
                    }
                }
            }
        }

        let avg = 10000.0 / targets.len() as f64;
        for count in distribution {
            let ratio = count as f64 / avg;
            assert!(
                ratio > 0.5 && ratio < 1.5,
                "Distribution too skewed: {}",
                ratio
            );
        }
    }

    /// Pre-loading target 0 far past the load cap must steer the next
    /// selection to some other target.
    #[tokio::test]
    async fn test_bounded_loads() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig {
            virtual_nodes: 50,
            bounded_loads: true,
            max_load_factor: 1.2,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // avg load = 110/3 ≈ 36.7, so cap ≈ 44; target 0 at 100 exceeds it.
        balancer.connection_counts[0].store(100, Ordering::Relaxed);
        balancer.total_connections.store(110, Ordering::Relaxed);

        let context = RequestContext {
            client_ip: Some("192.168.1.1:1234".parse().unwrap()),
            headers: HashMap::new(),
            path: "/".to_string(),
            method: "GET".to_string(),
        };

        let selection = balancer.select(Some(&context)).await.unwrap();
        let index = selection
            .metadata
            .get("target_index")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap();

        // The overloaded target must not be chosen.
        assert_ne!(index, 0);
    }

    /// Marking a target unhealthy must bump the ring generation (i.e.
    /// trigger exactly one rebuild) and drop it from the healthy set.
    #[tokio::test]
    async fn test_ring_rebuild_on_health_change() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig::default();

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let initial_generation = balancer.generation.load(Ordering::SeqCst);

        balancer.report_health("10.0.0.1:8080", false).await;

        let new_generation = balancer.generation.load(Ordering::SeqCst);
        assert_eq!(new_generation, initial_generation + 1);

        let healthy = balancer.healthy_targets().await;
        assert_eq!(healthy.len(), 2);
        assert!(!healthy.contains(&"10.0.0.1:8080".to_string()));
    }
}