1use murmur3::murmur3_32;
2use std::collections::{BTreeMap, HashMap};
3use std::hash::{Hash, Hasher};
4use std::io::Cursor;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use xxhash_rust::xxh3::Xxh3;
9
10use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
11use sentinel_common::errors::{SentinelError, SentinelResult};
12use async_trait::async_trait;
13use tracing::{debug, info, warn};
14
/// Hash function used both to place virtual nodes on the ring and to hash request keys.
///
/// Both sides always use the same function, so switching functions remaps every key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashFunction {
    /// 64-bit xxHash3. This is the value `ConsistentHashConfig::default()` selects.
    Xxh3,
    /// 32-bit MurmurHash3, widened to `u64`.
    Murmur3,
    /// `std`'s `DefaultHasher`; its algorithm is not guaranteed stable across Rust releases.
    DefaultHasher,
}
22
23#[derive(Debug, Clone)]
25pub struct ConsistentHashConfig {
26 pub virtual_nodes: usize,
28 pub hash_function: HashFunction,
30 pub bounded_loads: bool,
32 pub max_load_factor: f64,
34 pub hash_key_extractor: HashKeyExtractor,
36}
37
38impl Default for ConsistentHashConfig {
39 fn default() -> Self {
40 Self {
41 virtual_nodes: 150,
42 hash_function: HashFunction::Xxh3,
43 bounded_loads: true,
44 max_load_factor: 1.25,
45 hash_key_extractor: HashKeyExtractor::ClientIp,
46 }
47 }
48}
49
50#[derive(Clone)]
52pub enum HashKeyExtractor {
53 ClientIp,
54 Header(String),
55 Cookie(String),
56 Custom(Arc<dyn Fn(&RequestContext) -> Option<String> + Send + Sync>),
57}
58
59impl std::fmt::Debug for HashKeyExtractor {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::ClientIp => write!(f, "ClientIp"),
63 Self::Header(h) => write!(f, "Header({})", h),
64 Self::Cookie(c) => write!(f, "Cookie({})", c),
65 Self::Custom(_) => write!(f, "Custom"),
66 }
67 }
68}
69
/// One point on the hash ring, owned by the target at `target_index`.
#[derive(Debug, Clone)]
struct VirtualNode {
    /// Ring position; equal to the key this node is stored under in the ring map.
    // Never read by the balancer (the map key carries it); kept so `Debug` output is
    // self-describing. Without the allow this trips `dead_code` under `-D warnings`.
    #[allow(dead_code)]
    hash: u64,
    /// Index into the balancer's target list and per-target connection counters.
    target_index: usize,
    /// Which of the target's virtual nodes this is (0-based).
    // Diagnostic only, same reason as `hash`.
    #[allow(dead_code)]
    virtual_index: usize,
}
80
81pub struct ConsistentHashBalancer {
83 config: ConsistentHashConfig,
85 targets: Vec<UpstreamTarget>,
87 ring: Arc<RwLock<BTreeMap<u64, VirtualNode>>>,
89 health_status: Arc<RwLock<HashMap<String, bool>>>,
91 connection_counts: Vec<Arc<AtomicU64>>,
93 total_connections: Arc<AtomicU64>,
95 lookup_cache: Arc<RwLock<HashMap<u64, usize>>>,
97 generation: Arc<AtomicUsize>,
99}
100
101impl ConsistentHashBalancer {
102 pub fn new(targets: Vec<UpstreamTarget>, config: ConsistentHashConfig) -> Self {
103 let connection_counts = targets
104 .iter()
105 .map(|_| Arc::new(AtomicU64::new(0)))
106 .collect();
107
108 let balancer = Self {
109 config,
110 targets: targets.clone(),
111 ring: Arc::new(RwLock::new(BTreeMap::new())),
112 health_status: Arc::new(RwLock::new(HashMap::new())),
113 connection_counts,
114 total_connections: Arc::new(AtomicU64::new(0)),
115 lookup_cache: Arc::new(RwLock::new(HashMap::with_capacity(1000))),
116 generation: Arc::new(AtomicUsize::new(0)),
117 };
118
119 tokio::task::block_in_place(|| {
121 tokio::runtime::Handle::current().block_on(balancer.rebuild_ring());
122 });
123
124 balancer
125 }
126
127 async fn rebuild_ring(&self) {
129 let mut new_ring = BTreeMap::new();
130 let health = self.health_status.read().await;
131
132 for (index, target) in self.targets.iter().enumerate() {
133 let target_id = format!("{}:{}", target.address, target.port);
134 let is_healthy = health.get(&target_id).copied().unwrap_or(true);
135
136 if !is_healthy {
137 continue;
138 }
139
140 for vnode in 0..self.config.virtual_nodes {
142 let vnode_key = format!("{}-vnode-{}", target_id, vnode);
143 let hash = self.hash_key(&vnode_key);
144
145 new_ring.insert(
146 hash,
147 VirtualNode {
148 hash,
149 target_index: index,
150 virtual_index: vnode,
151 },
152 );
153 }
154 }
155
156 if new_ring.is_empty() {
157 warn!("No healthy targets available for consistent hash ring");
158 } else {
159 info!(
160 "Rebuilt consistent hash ring with {} virtual nodes from {} healthy targets",
161 new_ring.len(),
162 new_ring
163 .values()
164 .map(|n| n.target_index)
165 .collect::<std::collections::HashSet<_>>()
166 .len()
167 );
168 }
169
170 *self.ring.write().await = new_ring;
171
172 self.lookup_cache.write().await.clear();
174 self.generation.fetch_add(1, Ordering::SeqCst);
175 }
176
177 fn hash_key(&self, key: &str) -> u64 {
179 match self.config.hash_function {
180 HashFunction::Xxh3 => {
181 let mut hasher = Xxh3::new();
182 hasher.update(key.as_bytes());
183 hasher.digest()
184 }
185 HashFunction::Murmur3 => {
186 let mut cursor = Cursor::new(key.as_bytes());
187 murmur3_32(&mut cursor, 0).unwrap_or(0) as u64
188 }
189 HashFunction::DefaultHasher => {
190 use std::collections::hash_map::DefaultHasher;
191 let mut hasher = DefaultHasher::new();
192 key.hash(&mut hasher);
193 hasher.finish()
194 }
195 }
196 }
197
198 async fn find_target(&self, hash_key: &str) -> Option<usize> {
200 let key_hash = self.hash_key(hash_key);
201
202 {
204 let cache = self.lookup_cache.read().await;
205 if let Some(&target_index) = cache.get(&key_hash) {
206 let health = self.health_status.read().await;
208 let target = &self.targets[target_index];
209 let target_id = format!("{}:{}", target.address, target.port);
210 if health.get(&target_id).copied().unwrap_or(true) {
211 return Some(target_index);
212 }
213 }
214 }
215
216 let ring = self.ring.read().await;
217
218 if ring.is_empty() {
219 return None;
220 }
221
222 let candidates = if let Some((&_node_hash, vnode)) = ring.range(key_hash..).next() {
224 vec![vnode.clone()]
225 } else {
226 ring.iter()
228 .next()
229 .map(|(_, vnode)| vec![vnode.clone()])
230 .unwrap_or_default()
231 };
232
233 if !self.config.bounded_loads {
235 let target_index = candidates.first().map(|n| n.target_index);
236
237 if let Some(idx) = target_index {
239 self.lookup_cache.write().await.insert(key_hash, idx);
240 }
241
242 return target_index;
243 }
244
245 let avg_load = self.calculate_average_load().await;
247 let max_load = (avg_load * self.config.max_load_factor) as u64;
248
249 for vnode in candidates {
251 let current_load = self.connection_counts[vnode.target_index].load(Ordering::Relaxed);
252
253 if current_load <= max_load {
254 self.lookup_cache
256 .write()
257 .await
258 .insert(key_hash, vnode.target_index);
259 return Some(vnode.target_index);
260 }
261 }
262
263 self.find_least_loaded_target().await
265 }
266
267 async fn calculate_average_load(&self) -> f64 {
269 let health = self.health_status.read().await;
270 let healthy_count = self
271 .targets
272 .iter()
273 .filter(|t| {
274 let target_id = format!("{}:{}", t.address, t.port);
275 health.get(&target_id).copied().unwrap_or(true)
276 })
277 .count();
278
279 if healthy_count == 0 {
280 return 0.0;
281 }
282
283 let total = self.total_connections.load(Ordering::Relaxed);
284 total as f64 / healthy_count as f64
285 }
286
287 async fn find_least_loaded_target(&self) -> Option<usize> {
289 let health = self.health_status.read().await;
290
291 let mut min_load = u64::MAX;
292 let mut best_target = None;
293
294 for (index, target) in self.targets.iter().enumerate() {
295 let target_id = format!("{}:{}", target.address, target.port);
296 if !health.get(&target_id).copied().unwrap_or(true) {
297 continue;
298 }
299
300 let load = self.connection_counts[index].load(Ordering::Relaxed);
301 if load < min_load {
302 min_load = load;
303 best_target = Some(index);
304 }
305 }
306
307 best_target
308 }
309
310 pub fn extract_hash_key(&self, context: &RequestContext) -> Option<String> {
312 match &self.config.hash_key_extractor {
313 HashKeyExtractor::ClientIp => context.client_ip.map(|ip| ip.to_string()),
314 HashKeyExtractor::Header(name) => context.headers.get(name).cloned(),
315 HashKeyExtractor::Cookie(name) => {
316 context.headers.get("cookie").and_then(|cookies| {
318 cookies.split(';').find_map(|cookie| {
319 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
320 if parts.len() == 2 && parts[0] == name {
321 Some(parts[1].to_string())
322 } else {
323 None
324 }
325 })
326 })
327 }
328 HashKeyExtractor::Custom(extractor) => extractor(context),
329 }
330 }
331
332 pub fn acquire_connection(&self, target_index: usize) {
334 self.connection_counts[target_index].fetch_add(1, Ordering::Relaxed);
335 self.total_connections.fetch_add(1, Ordering::Relaxed);
336 }
337
338 pub fn release_connection(&self, target_index: usize) {
340 self.connection_counts[target_index].fetch_sub(1, Ordering::Relaxed);
341 self.total_connections.fetch_sub(1, Ordering::Relaxed);
342 }
343}
344
345#[async_trait]
346impl LoadBalancer for ConsistentHashBalancer {
347 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
348 let hash_key = context
350 .and_then(|ctx| self.extract_hash_key(ctx))
351 .unwrap_or_else(|| {
352 use rand::Rng;
354 let mut rng = rand::thread_rng();
355 format!("random-{}", rng.gen::<u64>())
356 });
357
358 let target_index = self
359 .find_target(&hash_key)
360 .await
361 .ok_or_else(|| SentinelError::NoHealthyUpstream)?;
362
363 let target = &self.targets[target_index];
364
365 if self.config.bounded_loads {
367 self.acquire_connection(target_index);
368 }
369
370 debug!(
371 "Selected target {} for hash key {} (index: {})",
372 format!("{}:{}", target.address, target.port),
373 hash_key,
374 target_index
375 );
376
377 Ok(TargetSelection {
378 address: format!("{}:{}", target.address, target.port),
379 weight: target.weight,
380 metadata: {
381 let mut meta = HashMap::new();
382 meta.insert("hash_key".to_string(), hash_key);
383 meta.insert("target_index".to_string(), target_index.to_string());
384 meta.insert(
385 "load".to_string(),
386 self.connection_counts[target_index]
387 .load(Ordering::Relaxed)
388 .to_string(),
389 );
390 meta
391 },
392 })
393 }
394
395 async fn report_health(&self, address: &str, healthy: bool) {
396 let mut health = self.health_status.write().await;
397 let previous = health.insert(address.to_string(), healthy);
398
399 if previous != Some(healthy) {
401 info!(
402 "Target {} health changed from {:?} to {}",
403 address, previous, healthy
404 );
405 self.rebuild_ring().await;
406 }
407 }
408
409 async fn healthy_targets(&self) -> Vec<String> {
410 let health = self.health_status.read().await;
411 self.targets
412 .iter()
413 .filter_map(|t| {
414 let target_id = format!("{}:{}", t.address, t.port);
415 if health.get(&target_id).copied().unwrap_or(true) {
416 Some(target_id)
417 } else {
418 None
419 }
420 })
421 .collect()
422 }
423
424 async fn release(&self, selection: &TargetSelection) {
426 if self.config.bounded_loads {
427 if let Some(index_str) = selection.metadata.get("target_index") {
428 if let Ok(index) = index_str.parse::<usize>() {
429 self.release_connection(index);
430 }
431 }
432 }
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    /// Builds a GET `/` request context from a client socket address and header pairs.
    fn request_context(client: &str, headers: &[(&str, &str)]) -> RequestContext {
        RequestContext {
            client_ip: Some(client.parse().unwrap()),
            headers: headers
                .iter()
                .map(|(name, value)| (name.to_string(), value.to_string()))
                .collect(),
            path: "/".to_string(),
            method: "GET".to_string(),
        }
    }

    /// Reads the target index the balancer recorded in a selection's metadata.
    fn selected_index(selection: &TargetSelection) -> usize {
        selection
            .metadata
            .get("target_index")
            .and_then(|raw| raw.parse().ok())
            .expect("selection should carry target_index metadata")
    }

    #[tokio::test]
    async fn test_consistent_distribution() {
        let targets = create_test_targets(5);
        let config = ConsistentHashConfig {
            virtual_nodes: 100,
            bounded_loads: false,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let requests = 10_000u32;
        let mut distribution = vec![0u64; targets.len()];

        for i in 0..requests {
            // Distinct client IPs so the check reflects the ring, not a handful of keys.
            let client = format!("192.168.{}.{}:1234", i / 256, i % 256);
            let context = request_context(&client, &[]);
            let selection = balancer
                .select(Some(&context))
                .await
                .expect("all targets are healthy, so selection must succeed");
            distribution[selected_index(&selection)] += 1;
        }

        assert_eq!(distribution.iter().sum::<u64>(), u64::from(requests));

        let avg = f64::from(requests) / targets.len() as f64;
        for count in distribution {
            let ratio = count as f64 / avg;
            assert!(
                ratio > 0.5 && ratio < 1.5,
                "Distribution too skewed: {}",
                ratio
            );
        }
    }

    #[tokio::test]
    async fn test_same_key_maps_to_same_target() {
        let config = ConsistentHashConfig {
            bounded_loads: false,
            ..Default::default()
        };
        let balancer = ConsistentHashBalancer::new(create_test_targets(5), config);
        let context = request_context("192.168.7.7:1234", &[]);

        let first = selected_index(&balancer.select(Some(&context)).await.unwrap());
        for _ in 0..20 {
            let again = selected_index(&balancer.select(Some(&context)).await.unwrap());
            assert_eq!(again, first);
        }
    }

    #[tokio::test]
    async fn test_bounded_loads() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig {
            virtual_nodes: 50,
            bounded_loads: true,
            max_load_factor: 1.2,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Target 0 is heavily loaded relative to the others.
        balancer.connection_counts[0].store(100, Ordering::Relaxed);
        balancer.total_connections.store(110, Ordering::Relaxed);

        let context = request_context("192.168.1.1:1234", &[]);

        let selection = balancer.select(Some(&context)).await.unwrap();
        assert_ne!(selected_index(&selection), 0);
    }

    #[tokio::test]
    async fn test_ring_rebuild_on_health_change() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig::default();

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let initial_generation = balancer.generation.load(Ordering::SeqCst);

        balancer.report_health("10.0.0.1:8080", false).await;

        let new_generation = balancer.generation.load(Ordering::SeqCst);
        assert_eq!(new_generation, initial_generation + 1);

        let healthy = balancer.healthy_targets().await;
        assert_eq!(healthy.len(), 2);
        assert!(!healthy.contains(&"10.0.0.1:8080".to_string()));

        // No request may be routed to the unhealthy target.
        for i in 0..50 {
            let context = request_context(&format!("172.16.0.{}:4000", i), &[]);
            let selection = balancer.select(Some(&context)).await.unwrap();
            assert_ne!(selected_index(&selection), 0);
        }

        balancer.report_health("10.0.0.1:8080", true).await;
        assert_eq!(balancer.healthy_targets().await.len(), 3);
    }

    #[tokio::test]
    async fn test_cookie_and_header_key_extraction() {
        let by_cookie = ConsistentHashBalancer::new(
            create_test_targets(2),
            ConsistentHashConfig {
                hash_key_extractor: HashKeyExtractor::Cookie("session".to_string()),
                ..Default::default()
            },
        );
        let with_session =
            request_context("192.168.1.1:1234", &[("cookie", "theme=dark; session=abc123")]);
        assert_eq!(
            by_cookie.extract_hash_key(&with_session),
            Some("abc123".to_string())
        );
        let without_session = request_context("192.168.1.1:1234", &[("cookie", "theme=dark")]);
        assert_eq!(by_cookie.extract_hash_key(&without_session), None);

        let by_header = ConsistentHashBalancer::new(
            create_test_targets(2),
            ConsistentHashConfig {
                hash_key_extractor: HashKeyExtractor::Header("x-user-id".to_string()),
                ..Default::default()
            },
        );
        let with_header = request_context("192.168.1.1:1234", &[("x-user-id", "42")]);
        assert_eq!(by_header.extract_hash_key(&with_header), Some("42".to_string()));
        assert_eq!(by_header.extract_hash_key(&without_session), None);
    }
}
549}