use murmur3::murmur3_32;
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::io::Cursor;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use xxhash_rust::xxh3::Xxh3;

use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
use async_trait::async_trait;
use sentinel_common::errors::{SentinelError, SentinelResult};
use tracing::{debug, info, trace, warn};

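/// Hash function used to place virtual nodes and request keys on the ring.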
#[derive(Debug, Clone, Copy)]
pub enum HashFunction {
    Xxh3,
    Murmur3,
    DefaultHasher,
}

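/// Configuration for the consistent hash balancer.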
#[derive(Debug, Clone)]
pub struct ConsistentHashConfig {
    /// Number of virtual nodes placed on the ring per target.
    pub virtual_nodes: usize,
    /// Hash function used for virtual nodes and request keys.
    pub hash_function: HashFunction,
    /// Enable the bounded-loads variant of consistent hashing.
    pub bounded_loads: bool,
    /// Maximum allowed load per target, as a multiple of the average load.
    pub max_load_factor: f64,
    /// How the hash key is derived from a request.
    pub hash_key_extractor: HashKeyExtractor,
}

impl Default for ConsistentHashConfig {
    fn default() -> Self {
        Self {
            virtual_nodes: 150,
            hash_function: HashFunction::Xxh3,
            bounded_loads: true,
            max_load_factor: 1.25,
            hash_key_extractor: HashKeyExtractor::ClientIp,
        }
    }
}

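/// Strategy for deriving the hash key from an incoming request.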
#[derive(Clone)]
pub enum HashKeyExtractor {
    ClientIp,
    Header(String),
    Cookie(String),
    Custom(Arc<dyn Fn(&RequestContext) -> Option<String> + Send + Sync>),
}

impl std::fmt::Debug for HashKeyExtractor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ClientIp => write!(f, "ClientIp"),
            Self::Header(h) => write!(f, "Header({})", h),
            Self::Cookie(c) => write!(f, "Cookie({})", c),
            Self::Custom(_) => write!(f, "Custom"),
        }
    }
}

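/// A virtual node: one of the `virtual_nodes` replicas of a target on the ring.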
#[derive(Debug, Clone)]
struct VirtualNode {
    /// Position of this virtual node on the ring.
    hash: u64,
    /// Index of the backing target in the balancer's `targets` list.
    target_index: usize,
    /// Replica number of this virtual node for its target.
    virtual_index: usize,
}

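/// Consistent-hash load balancer.
///
/// Each healthy target is mapped to `virtual_nodes` points on a hash ring; a
/// request key is routed to the first virtual node at or after its hash,
/// wrapping around at the end of the ring. With `bounded_loads` enabled, a
/// target whose connection count exceeds `max_load_factor` times the average
/// is skipped in favor of the least-loaded healthy target.
///
/// Minimal usage sketch (illustrative values, not from a real deployment;
/// assumes a multi-threaded Tokio runtime and some `request_context` in scope):
///
/// ```ignore
/// let targets = vec![UpstreamTarget {
///     address: "10.0.0.1".to_string(),
///     port: 8080,
///     weight: 100,
/// }];
/// let balancer = ConsistentHashBalancer::new(targets, ConsistentHashConfig::default());
/// let selection = balancer.select(Some(&request_context)).await?;
/// println!("routing to {}", selection.address);
/// ```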
pub struct ConsistentHashBalancer {
    config: ConsistentHashConfig,
    targets: Vec<UpstreamTarget>,
    /// The hash ring: ring position (hash) -> virtual node.
    ring: Arc<RwLock<BTreeMap<u64, VirtualNode>>>,
    /// Health status keyed by "address:port".
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Active connection count per target, indexed like `targets`.
    connection_counts: Vec<Arc<AtomicU64>>,
    /// Total active connections across all targets.
    total_connections: Arc<AtomicU64>,
    /// Cache from key hash to target index; cleared on every ring rebuild.
    lookup_cache: Arc<RwLock<HashMap<u64, usize>>>,
    /// Ring generation, incremented on every rebuild.
    generation: Arc<AtomicUsize>,
}

impl ConsistentHashBalancer {
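    /// Creates the balancer and performs the initial ring build.
    ///
    /// Note: the initial build blocks on the current runtime via
    /// `tokio::task::block_in_place`, so `new` must be called from within a
    /// multi-threaded Tokio runtime.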
    pub fn new(targets: Vec<UpstreamTarget>, config: ConsistentHashConfig) -> Self {
        trace!(
            target_count = targets.len(),
            virtual_nodes = config.virtual_nodes,
            hash_function = ?config.hash_function,
            bounded_loads = config.bounded_loads,
            max_load_factor = config.max_load_factor,
            hash_key_extractor = ?config.hash_key_extractor,
            "Creating consistent hash balancer"
        );

        let connection_counts = targets
            .iter()
            .map(|_| Arc::new(AtomicU64::new(0)))
            .collect();

        let balancer = Self {
            config,
            targets: targets.clone(),
            ring: Arc::new(RwLock::new(BTreeMap::new())),
            health_status: Arc::new(RwLock::new(HashMap::new())),
            connection_counts,
            total_connections: Arc::new(AtomicU64::new(0)),
            lookup_cache: Arc::new(RwLock::new(HashMap::with_capacity(1000))),
            generation: Arc::new(AtomicUsize::new(0)),
        };

        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(balancer.rebuild_ring());
        });

        debug!(
            target_count = targets.len(),
            "Consistent hash balancer initialized"
        );

        balancer
    }

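    /// Rebuilds the hash ring from the currently healthy targets.
    ///
    /// Inserts `virtual_nodes` virtual nodes per healthy target, then clears
    /// the lookup cache and bumps the generation counter.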
    async fn rebuild_ring(&self) {
        trace!(
            total_targets = self.targets.len(),
            virtual_nodes_per_target = self.config.virtual_nodes,
            "Starting hash ring rebuild"
        );

        let mut new_ring = BTreeMap::new();
        let health = self.health_status.read().await;

        for (index, target) in self.targets.iter().enumerate() {
            let target_id = format!("{}:{}", target.address, target.port);
            let is_healthy = health.get(&target_id).copied().unwrap_or(true);

            if !is_healthy {
                trace!(
                    target_id = %target_id,
                    target_index = index,
                    "Skipping unhealthy target in ring rebuild"
                );
                continue;
            }

            for vnode in 0..self.config.virtual_nodes {
                let vnode_key = format!("{}-vnode-{}", target_id, vnode);
                let hash = self.hash_key(&vnode_key);

                new_ring.insert(
                    hash,
                    VirtualNode {
                        hash,
                        target_index: index,
                        virtual_index: vnode,
                    },
                );
            }

            trace!(
                target_id = %target_id,
                target_index = index,
                vnodes_added = self.config.virtual_nodes,
                "Added virtual nodes for target"
            );
        }

        let healthy_count = new_ring
            .values()
            .map(|n| n.target_index)
            .collect::<std::collections::HashSet<_>>()
            .len();

        if new_ring.is_empty() {
            warn!("No healthy targets available for consistent hash ring");
        } else {
            info!(
                virtual_nodes = new_ring.len(),
                healthy_targets = healthy_count,
                "Rebuilt consistent hash ring"
            );
        }

        *self.ring.write().await = new_ring;

        let cache_size = self.lookup_cache.read().await.len();
        self.lookup_cache.write().await.clear();
        let new_generation = self.generation.fetch_add(1, Ordering::SeqCst) + 1;

        trace!(
            cache_entries_cleared = cache_size,
            new_generation = new_generation,
            "Ring rebuild complete, cache cleared"
        );
    }

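    /// Hashes `key` with the configured hash function.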
    fn hash_key(&self, key: &str) -> u64 {
        match self.config.hash_function {
            HashFunction::Xxh3 => {
                let mut hasher = Xxh3::new();
                hasher.update(key.as_bytes());
                hasher.digest()
            }
            HashFunction::Murmur3 => {
                let mut cursor = Cursor::new(key.as_bytes());
                murmur3_32(&mut cursor, 0).unwrap_or(0) as u64
            }
            HashFunction::DefaultHasher => {
                use std::collections::hash_map::DefaultHasher;
                let mut hasher = DefaultHasher::new();
                key.hash(&mut hasher);
                hasher.finish()
            }
        }
    }

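    /// Resolves `hash_key` to a target index.
    ///
    /// Checks the lookup cache first, then walks the ring to the first virtual
    /// node at or after the key's hash (wrapping around). When bounded loads
    /// are enabled and that target is above the load bound, falls back to the
    /// least-loaded healthy target.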
    async fn find_target(&self, hash_key: &str) -> Option<usize> {
        let key_hash = self.hash_key(hash_key);

        trace!(
            hash_key = %hash_key,
            key_hash = key_hash,
            bounded_loads = self.config.bounded_loads,
            "Finding target for hash key"
        );

        {
            let cache = self.lookup_cache.read().await;
            if let Some(&target_index) = cache.get(&key_hash) {
                let health = self.health_status.read().await;
                let target = &self.targets[target_index];
                let target_id = format!("{}:{}", target.address, target.port);
                if health.get(&target_id).copied().unwrap_or(true) {
                    trace!(
                        hash_key = %hash_key,
                        target_index = target_index,
                        "Cache hit for hash key"
                    );
                    return Some(target_index);
                }
                trace!(
                    hash_key = %hash_key,
                    target_index = target_index,
                    "Cache hit but target unhealthy"
                );
            }
        }

        let ring = self.ring.read().await;

        if ring.is_empty() {
            warn!("Hash ring is empty, no targets available");
            return None;
        }

        let candidates = if let Some((&_node_hash, vnode)) = ring.range(key_hash..).next() {
            vec![vnode.clone()]
        } else {
            ring.iter()
                .next()
                .map(|(_, vnode)| vec![vnode.clone()])
                .unwrap_or_default()
        };

        trace!(
            hash_key = %hash_key,
            candidate_count = candidates.len(),
            "Found candidates on hash ring"
        );

        if !self.config.bounded_loads {
            let target_index = candidates.first().map(|n| n.target_index);

            if let Some(idx) = target_index {
                self.lookup_cache.write().await.insert(key_hash, idx);
                trace!(
                    hash_key = %hash_key,
                    target_index = idx,
                    "Selected target (no bounded loads)"
                );
            }

            return target_index;
        }

        let avg_load = self.calculate_average_load().await;
        let max_load = (avg_load * self.config.max_load_factor) as u64;

        trace!(
            avg_load = avg_load,
            max_load = max_load,
            max_load_factor = self.config.max_load_factor,
            "Checking bounded loads"
        );

        for vnode in candidates {
            let current_load = self.connection_counts[vnode.target_index].load(Ordering::Relaxed);

            trace!(
                target_index = vnode.target_index,
                current_load = current_load,
                max_load = max_load,
                "Evaluating candidate load"
            );

            if current_load <= max_load {
                self.lookup_cache
                    .write()
                    .await
                    .insert(key_hash, vnode.target_index);
                debug!(
                    hash_key = %hash_key,
                    target_index = vnode.target_index,
                    current_load = current_load,
                    "Selected target within load bounds"
                );
                return Some(vnode.target_index);
            }
        }

        trace!(
            hash_key = %hash_key,
            "All candidates overloaded, falling back to least loaded"
        );

        self.find_least_loaded_target().await
    }

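    /// Average number of active connections per healthy target (0.0 if none).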
    async fn calculate_average_load(&self) -> f64 {
        let health = self.health_status.read().await;
        let healthy_count = self
            .targets
            .iter()
            .filter(|t| {
                let target_id = format!("{}:{}", t.address, t.port);
                health.get(&target_id).copied().unwrap_or(true)
            })
            .count();

        if healthy_count == 0 {
            return 0.0;
        }

        let total = self.total_connections.load(Ordering::Relaxed);
        total as f64 / healthy_count as f64
    }

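    /// Returns the healthy target with the fewest active connections, if any.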
    async fn find_least_loaded_target(&self) -> Option<usize> {
        trace!("Finding least loaded target as fallback");

        let health = self.health_status.read().await;

        let mut min_load = u64::MAX;
        let mut best_target = None;

        for (index, target) in self.targets.iter().enumerate() {
            let target_id = format!("{}:{}", target.address, target.port);
            if !health.get(&target_id).copied().unwrap_or(true) {
                trace!(
                    target_index = index,
                    target_id = %target_id,
                    "Skipping unhealthy target"
                );
                continue;
            }

            let load = self.connection_counts[index].load(Ordering::Relaxed);
            trace!(
                target_index = index,
                target_id = %target_id,
                load = load,
                "Evaluating target load"
            );

            if load < min_load {
                min_load = load;
                best_target = Some(index);
            }
        }

        if let Some(idx) = best_target {
            debug!(
                target_index = idx,
                load = min_load,
                "Selected least loaded target"
            );
        } else {
            warn!("No healthy targets found for least loaded selection");
        }

        best_target
    }

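    /// Extracts the hash key from the request using the configured extractor
    /// (client IP, header value, cookie value, or a custom function).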
    pub fn extract_hash_key(&self, context: &RequestContext) -> Option<String> {
        let key = match &self.config.hash_key_extractor {
            HashKeyExtractor::ClientIp => context.client_ip.map(|ip| ip.to_string()),
            HashKeyExtractor::Header(name) => context.headers.get(name).cloned(),
            HashKeyExtractor::Cookie(name) => {
                context.headers.get("cookie").and_then(|cookies| {
                    cookies.split(';').find_map(|cookie| {
                        let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
                        if parts.len() == 2 && parts[0] == name {
                            Some(parts[1].to_string())
                        } else {
                            None
                        }
                    })
                })
            }
            HashKeyExtractor::Custom(extractor) => extractor(context),
        };

        trace!(
            extractor = ?self.config.hash_key_extractor,
            key_found = key.is_some(),
            "Extracted hash key from request"
        );

        key
    }

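    /// Increments the per-target and total connection counters (bounded-loads accounting).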
    pub fn acquire_connection(&self, target_index: usize) {
        let count = self.connection_counts[target_index].fetch_add(1, Ordering::Relaxed) + 1;
        let total = self.total_connections.fetch_add(1, Ordering::Relaxed) + 1;
        trace!(
            target_index = target_index,
            target_connections = count,
            total_connections = total,
            "Acquired connection"
        );
    }

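    /// Decrements the per-target and total connection counters.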
    pub fn release_connection(&self, target_index: usize) {
        let count = self.connection_counts[target_index].fetch_sub(1, Ordering::Relaxed) - 1;
        let total = self.total_connections.fetch_sub(1, Ordering::Relaxed) - 1;
        trace!(
            target_index = target_index,
            target_connections = count,
            total_connections = total,
            "Released connection"
        );
    }
}

#[async_trait]
impl LoadBalancer for ConsistentHashBalancer {
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            has_context = context.is_some(),
            "Consistent hash select called"
        );

        let (hash_key, used_random) = context
            .and_then(|ctx| self.extract_hash_key(ctx))
            .map(|k| (k, false))
            .unwrap_or_else(|| {
                use rand::Rng;
                let mut rng = rand::thread_rng();
                let key = format!("random-{}", rng.gen::<u64>());
                trace!(random_key = %key, "Generated random hash key (no context key)");
                (key, true)
            });

        let target_index = self.find_target(&hash_key).await.ok_or_else(|| {
            warn!("No healthy upstream targets available");
            SentinelError::NoHealthyUpstream
        })?;

        let target = &self.targets[target_index];

        if self.config.bounded_loads {
            self.acquire_connection(target_index);
        }

        let current_load = self.connection_counts[target_index].load(Ordering::Relaxed);

        debug!(
            target = %format!("{}:{}", target.address, target.port),
            hash_key = %hash_key,
            target_index = target_index,
            current_load = current_load,
            used_random_key = used_random,
            "Consistent hash selected target"
        );

        Ok(TargetSelection {
            address: format!("{}:{}", target.address, target.port),
            weight: target.weight,
            metadata: {
                let mut meta = HashMap::new();
                meta.insert("hash_key".to_string(), hash_key);
                meta.insert("target_index".to_string(), target_index.to_string());
                meta.insert("load".to_string(), current_load.to_string());
                meta.insert("algorithm".to_string(), "consistent_hash".to_string());
                meta
            },
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            address = %address,
            healthy = healthy,
            "Reporting target health"
        );

        let mut health = self.health_status.write().await;
        let previous = health.insert(address.to_string(), healthy);

        if previous != Some(healthy) {
            info!(
                address = %address,
                previous_status = ?previous,
                new_status = healthy,
                "Target health changed, rebuilding ring"
            );
            // Release the write lock before rebuilding: rebuild_ring takes a
            // read lock on health_status and would otherwise deadlock.
            drop(health);
            self.rebuild_ring().await;
        }
    }

    async fn healthy_targets(&self) -> Vec<String> {
        let health = self.health_status.read().await;
        let targets: Vec<String> = self
            .targets
            .iter()
            .filter_map(|t| {
                let target_id = format!("{}:{}", t.address, t.port);
                if health.get(&target_id).copied().unwrap_or(true) {
                    Some(target_id)
                } else {
                    None
                }
            })
            .collect();

        trace!(
            total_targets = self.targets.len(),
            healthy_count = targets.len(),
            "Retrieved healthy targets"
        );

        targets
    }

    async fn release(&self, selection: &TargetSelection) {
        if self.config.bounded_loads {
            if let Some(index_str) = selection.metadata.get("target_index") {
                if let Ok(index) = index_str.parse::<usize>() {
                    trace!(
                        target_index = index,
                        address = %selection.address,
                        "Releasing connection for bounded loads"
                    );
                    self.release_connection(index);
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_consistent_distribution() {
        let targets = create_test_targets(5);
        let config = ConsistentHashConfig {
            virtual_nodes: 100,
            bounded_loads: false,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let mut distribution = vec![0u64; targets.len()];

        // Route 10,000 requests across 256 distinct client IPs and record
        // which target each one lands on.
        for i in 0..10000 {
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:1234", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: "/".to_string(),
                method: "GET".to_string(),
            };

            if let Ok(selection) = balancer.select(Some(&context)).await {
                if let Some(index_str) = selection.metadata.get("target_index") {
                    if let Ok(index) = index_str.parse::<usize>() {
                        distribution[index] += 1;
                    }
                }
            }
        }

        let avg = 10000.0 / targets.len() as f64;
        for count in distribution {
            let ratio = count as f64 / avg;
            assert!(
                ratio > 0.5 && ratio < 1.5,
                "Distribution too skewed: {}",
                ratio
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_bounded_loads() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig {
            virtual_nodes: 50,
            bounded_loads: true,
            max_load_factor: 1.2,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Pre-load target 0 well above the bounded-load threshold.
        balancer.connection_counts[0].store(100, Ordering::Relaxed);
        balancer.total_connections.store(110, Ordering::Relaxed);

        let context = RequestContext {
            client_ip: Some("192.168.1.1:1234".parse().unwrap()),
            headers: HashMap::new(),
            path: "/".to_string(),
            method: "GET".to_string(),
        };

        let selection = balancer.select(Some(&context)).await.unwrap();
        let index = selection
            .metadata
            .get("target_index")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap();

        assert_ne!(index, 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_ring_rebuild_on_health_change() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig::default();

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let initial_generation = balancer.generation.load(Ordering::SeqCst);

        balancer.report_health("10.0.0.1:8080", false).await;

        let new_generation = balancer.generation.load(Ordering::SeqCst);
        assert_eq!(new_generation, initial_generation + 1);

        let healthy = balancer.healthy_targets().await;
        assert_eq!(healthy.len(), 2);
        assert!(!healthy.contains(&"10.0.0.1:8080".to_string()));
    }
}