// sentinel_proxy/upstream/consistent_hash.rs

1use murmur3::murmur3_32;
2use std::collections::{BTreeMap, HashMap};
3use std::hash::{Hash, Hasher};
4use std::io::Cursor;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use xxhash_rust::xxh3::Xxh3;
9
10use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
11use async_trait::async_trait;
12use sentinel_common::errors::{SentinelError, SentinelResult};
13use tracing::{debug, info, trace, warn};
14
/// Hash function types supported by the consistent hash balancer.
///
/// `Xxh3` produces a full 64-bit digest; `Murmur3` yields only 32 bits
/// of entropy (zero-extended to `u64` in `hash_key`), so ring positions
/// cluster in the low half of the 64-bit key space; `DefaultHasher` is
/// std's SipHash, whose output is stable within a process but not
/// guaranteed across Rust releases.
#[derive(Debug, Clone, Copy)]
pub enum HashFunction {
    /// xxHash3, 64-bit digest (the default).
    Xxh3,
    /// MurmurHash3 32-bit, widened to 64 bits.
    Murmur3,
    /// `std::collections::hash_map::DefaultHasher` (SipHash).
    DefaultHasher,
}
22
/// Configuration for consistent hashing.
#[derive(Debug, Clone)]
pub struct ConsistentHashConfig {
    /// Number of virtual nodes per real target. More virtual nodes give
    /// a smoother key distribution at the cost of a larger ring.
    pub virtual_nodes: usize,
    /// Hash function used both for ring placement and for key lookup.
    pub hash_function: HashFunction,
    /// Enable bounded loads to prevent overload of a single target.
    pub bounded_loads: bool,
    /// Maximum load factor (1.0 = average load, 1.25 = 25% above average).
    /// Only consulted when `bounded_loads` is true.
    pub max_load_factor: f64,
    /// Key extraction strategy (e.g., from headers, cookies, client IP).
    pub hash_key_extractor: HashKeyExtractor,
}
37
38impl Default for ConsistentHashConfig {
39    fn default() -> Self {
40        Self {
41            virtual_nodes: 150,
42            hash_function: HashFunction::Xxh3,
43            bounded_loads: true,
44            max_load_factor: 1.25,
45            hash_key_extractor: HashKeyExtractor::ClientIp,
46        }
47    }
48}
49
/// Defines how to extract the hash key from a request
#[derive(Clone)]
pub enum HashKeyExtractor {
    /// Use the stringified `RequestContext::client_ip`.
    /// NOTE(review): the tests parse this field from an `ip:port` string,
    /// so the source port appears to participate in the key — confirm
    /// that per-connection (rather than per-host) affinity is intended.
    ClientIp,
    /// Use the value of the named request header.
    Header(String),
    /// Use the value of the named cookie parsed out of the `cookie` header.
    Cookie(String),
    /// Arbitrary caller-supplied extractor; returns `None` when no key
    /// can be derived from the request.
    Custom(Arc<dyn Fn(&RequestContext) -> Option<String> + Send + Sync>),
}
58
59impl std::fmt::Debug for HashKeyExtractor {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Self::ClientIp => write!(f, "ClientIp"),
63            Self::Header(h) => write!(f, "Header({})", h),
64            Self::Cookie(c) => write!(f, "Cookie({})", c),
65            Self::Custom(_) => write!(f, "Custom"),
66        }
67    }
68}
69
/// Virtual node in the consistent hash ring
#[derive(Debug, Clone)]
struct VirtualNode {
    /// Hash value of this virtual node. Duplicates the `BTreeMap` key it
    /// is stored under; kept on the node itself for `Debug` output.
    hash: u64,
    /// Index into `ConsistentHashBalancer::targets` for the real target
    /// this virtual node represents.
    target_index: usize,
    /// Ordinal of this virtual node among the target's virtual nodes
    /// (0..virtual_nodes). Written but never read in this file —
    /// retained for diagnostics.
    virtual_index: usize,
}
80
/// Consistent hash load balancer with virtual nodes and bounded loads
pub struct ConsistentHashBalancer {
    /// Configuration (virtual-node count, hash function, load bounds).
    config: ConsistentHashConfig,
    /// All upstream targets, healthy or not. Indices into this vec are
    /// the `target_index` values used throughout the balancer.
    targets: Vec<UpstreamTarget>,
    /// Hash ring: virtual-node hash -> virtual node, ordered by hash.
    /// Rebuilt whenever target health changes.
    ring: Arc<RwLock<BTreeMap<u64, VirtualNode>>>,
    /// Health per target id (`"address:port"`); a missing entry is
    /// treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Active connection count per target (for bounded loads), indexed
    /// in lockstep with `targets`.
    connection_counts: Vec<Arc<AtomicU64>>,
    /// Total active connections across all targets.
    total_connections: Arc<AtomicU64>,
    /// Cache for recent hash lookups (key hash -> target_index). Cleared
    /// on every ring rebuild; between rebuilds it grows without bound.
    lookup_cache: Arc<RwLock<HashMap<u64, usize>>>,
    /// Generation counter, incremented on every ring rebuild so callers
    /// can detect ring changes.
    generation: Arc<AtomicUsize>,
}
100
101impl ConsistentHashBalancer {
102    pub fn new(targets: Vec<UpstreamTarget>, config: ConsistentHashConfig) -> Self {
103        trace!(
104            target_count = targets.len(),
105            virtual_nodes = config.virtual_nodes,
106            hash_function = ?config.hash_function,
107            bounded_loads = config.bounded_loads,
108            max_load_factor = config.max_load_factor,
109            hash_key_extractor = ?config.hash_key_extractor,
110            "Creating consistent hash balancer"
111        );
112
113        let connection_counts = targets
114            .iter()
115            .map(|_| Arc::new(AtomicU64::new(0)))
116            .collect();
117
118        let balancer = Self {
119            config,
120            targets: targets.clone(),
121            ring: Arc::new(RwLock::new(BTreeMap::new())),
122            health_status: Arc::new(RwLock::new(HashMap::new())),
123            connection_counts,
124            total_connections: Arc::new(AtomicU64::new(0)),
125            lookup_cache: Arc::new(RwLock::new(HashMap::with_capacity(1000))),
126            generation: Arc::new(AtomicUsize::new(0)),
127        };
128
129        // Build initial ring
130        tokio::task::block_in_place(|| {
131            tokio::runtime::Handle::current().block_on(balancer.rebuild_ring());
132        });
133
134        debug!(
135            target_count = targets.len(),
136            "Consistent hash balancer initialized"
137        );
138
139        balancer
140    }
141
142    /// Rebuild the hash ring based on current targets and health
143    async fn rebuild_ring(&self) {
144        trace!(
145            total_targets = self.targets.len(),
146            virtual_nodes_per_target = self.config.virtual_nodes,
147            "Starting hash ring rebuild"
148        );
149
150        let mut new_ring = BTreeMap::new();
151        let health = self.health_status.read().await;
152
153        for (index, target) in self.targets.iter().enumerate() {
154            let target_id = format!("{}:{}", target.address, target.port);
155            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
156
157            if !is_healthy {
158                trace!(
159                    target_id = %target_id,
160                    target_index = index,
161                    "Skipping unhealthy target in ring rebuild"
162                );
163                continue;
164            }
165
166            // Add virtual nodes for this target
167            for vnode in 0..self.config.virtual_nodes {
168                let vnode_key = format!("{}-vnode-{}", target_id, vnode);
169                let hash = self.hash_key(&vnode_key);
170
171                new_ring.insert(
172                    hash,
173                    VirtualNode {
174                        hash,
175                        target_index: index,
176                        virtual_index: vnode,
177                    },
178                );
179            }
180
181            trace!(
182                target_id = %target_id,
183                target_index = index,
184                vnodes_added = self.config.virtual_nodes,
185                "Added virtual nodes for target"
186            );
187        }
188
189        let healthy_count = new_ring
190            .values()
191            .map(|n| n.target_index)
192            .collect::<std::collections::HashSet<_>>()
193            .len();
194
195        if new_ring.is_empty() {
196            warn!("No healthy targets available for consistent hash ring");
197        } else {
198            info!(
199                virtual_nodes = new_ring.len(),
200                healthy_targets = healthy_count,
201                "Rebuilt consistent hash ring"
202            );
203        }
204
205        *self.ring.write().await = new_ring;
206
207        // Clear cache on ring change
208        let cache_size = self.lookup_cache.read().await.len();
209        self.lookup_cache.write().await.clear();
210        let new_generation = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
211
212        trace!(
213            cache_entries_cleared = cache_size,
214            new_generation = new_generation,
215            "Ring rebuild complete, cache cleared"
216        );
217    }
218
219    /// Hash a key using the configured hash function
220    fn hash_key(&self, key: &str) -> u64 {
221        match self.config.hash_function {
222            HashFunction::Xxh3 => {
223                let mut hasher = Xxh3::new();
224                hasher.update(key.as_bytes());
225                hasher.digest()
226            }
227            HashFunction::Murmur3 => {
228                let mut cursor = Cursor::new(key.as_bytes());
229                murmur3_32(&mut cursor, 0).unwrap_or(0) as u64
230            }
231            HashFunction::DefaultHasher => {
232                use std::collections::hash_map::DefaultHasher;
233                let mut hasher = DefaultHasher::new();
234                key.hash(&mut hasher);
235                hasher.finish()
236            }
237        }
238    }
239
    /// Find target using consistent hashing with optional bounded loads.
    ///
    /// Lookup order:
    /// 1. lookup-cache hit, re-validated against current health;
    /// 2. first virtual node clockwise from the key's hash on the ring
    ///    (wrapping to the lowest hash when past the end);
    /// 3. with bounded loads enabled, a least-loaded fallback when that
    ///    node's target exceeds `avg_load * max_load_factor`.
    ///
    /// Returns `None` only when the ring is empty (no healthy targets).
    async fn find_target(&self, hash_key: &str) -> Option<usize> {
        let key_hash = self.hash_key(hash_key);

        trace!(
            hash_key = %hash_key,
            key_hash = key_hash,
            bounded_loads = self.config.bounded_loads,
            "Finding target for hash key"
        );

        // Check cache first (scoped so both read guards drop before the
        // ring is locked below).
        {
            let cache = self.lookup_cache.read().await;
            if let Some(&target_index) = cache.get(&key_hash) {
                // Verify the cached target is still healthy; a stale entry
                // can survive briefly between ring swap and cache clear.
                let health = self.health_status.read().await;
                let target = &self.targets[target_index];
                let target_id = format!("{}:{}", target.address, target.port);
                if health.get(&target_id).copied().unwrap_or(true) {
                    trace!(
                        hash_key = %hash_key,
                        target_index = target_index,
                        "Cache hit for hash key"
                    );
                    return Some(target_index);
                }
                // Unhealthy: fall through to a fresh ring lookup.
                trace!(
                    hash_key = %hash_key,
                    target_index = target_index,
                    "Cache hit but target unhealthy"
                );
            }
        }

        let ring = self.ring.read().await;

        if ring.is_empty() {
            warn!("Hash ring is empty, no targets available");
            return None;
        }

        // Find the first virtual node with hash >= key_hash.
        // NOTE(review): `candidates` always holds at most ONE node (the
        // clockwise successor, or the ring's first node on wrap-around),
        // so the bounded-loads loop below never walks further along the
        // ring — it goes straight to the least-loaded fallback instead.
        let candidates = if let Some((&_node_hash, vnode)) = ring.range(key_hash..).next() {
            vec![vnode.clone()]
        } else {
            // Wrap around to the first node
            ring.iter()
                .next()
                .map(|(_, vnode)| vec![vnode.clone()])
                .unwrap_or_default()
        };

        trace!(
            hash_key = %hash_key,
            candidate_count = candidates.len(),
            "Found candidates on hash ring"
        );

        // If bounded loads is disabled, return the first candidate.
        if !self.config.bounded_loads {
            let target_index = candidates.first().map(|n| n.target_index);

            // Update cache. NOTE(review): entries are only evicted by the
            // full clear in `rebuild_ring`; with many distinct keys this
            // map grows without bound between rebuilds.
            if let Some(idx) = target_index {
                self.lookup_cache.write().await.insert(key_hash, idx);
                trace!(
                    hash_key = %hash_key,
                    target_index = idx,
                    "Selected target (no bounded loads)"
                );
            }

            return target_index;
        }

        // Bounded loads: check if target is overloaded relative to
        // avg_load * max_load_factor (truncated to u64).
        let avg_load = self.calculate_average_load().await;
        let max_load = (avg_load * self.config.max_load_factor) as u64;

        trace!(
            avg_load = avg_load,
            max_load = max_load,
            max_load_factor = self.config.max_load_factor,
            "Checking bounded loads"
        );

        // Try candidates in order until we find one that's not overloaded.
        for vnode in candidates {
            let current_load = self.connection_counts[vnode.target_index].load(Ordering::Relaxed);

            trace!(
                target_index = vnode.target_index,
                current_load = current_load,
                max_load = max_load,
                "Evaluating candidate load"
            );

            if current_load <= max_load {
                // Update cache (same unbounded-growth caveat as above).
                self.lookup_cache
                    .write()
                    .await
                    .insert(key_hash, vnode.target_index);
                debug!(
                    hash_key = %hash_key,
                    target_index = vnode.target_index,
                    current_load = current_load,
                    "Selected target within load bounds"
                );
                return Some(vnode.target_index);
            }
        }

        trace!(
            hash_key = %hash_key,
            "All candidates overloaded, falling back to least loaded"
        );

        // If all candidates are overloaded, find least loaded target.
        // Note: this result is deliberately NOT cached, since it does not
        // reflect the key's ring position.
        self.find_least_loaded_target().await
    }
362
363    /// Calculate average load across all healthy targets
364    async fn calculate_average_load(&self) -> f64 {
365        let health = self.health_status.read().await;
366        let healthy_count = self
367            .targets
368            .iter()
369            .filter(|t| {
370                let target_id = format!("{}:{}", t.address, t.port);
371                health.get(&target_id).copied().unwrap_or(true)
372            })
373            .count();
374
375        if healthy_count == 0 {
376            return 0.0;
377        }
378
379        let total = self.total_connections.load(Ordering::Relaxed);
380        total as f64 / healthy_count as f64
381    }
382
383    /// Find the least loaded target when all consistent hash candidates are overloaded
384    async fn find_least_loaded_target(&self) -> Option<usize> {
385        trace!("Finding least loaded target as fallback");
386
387        let health = self.health_status.read().await;
388
389        let mut min_load = u64::MAX;
390        let mut best_target = None;
391
392        for (index, target) in self.targets.iter().enumerate() {
393            let target_id = format!("{}:{}", target.address, target.port);
394            if !health.get(&target_id).copied().unwrap_or(true) {
395                trace!(
396                    target_index = index,
397                    target_id = %target_id,
398                    "Skipping unhealthy target"
399                );
400                continue;
401            }
402
403            let load = self.connection_counts[index].load(Ordering::Relaxed);
404            trace!(
405                target_index = index,
406                target_id = %target_id,
407                load = load,
408                "Evaluating target load"
409            );
410
411            if load < min_load {
412                min_load = load;
413                best_target = Some(index);
414            }
415        }
416
417        if let Some(idx) = best_target {
418            debug!(
419                target_index = idx,
420                load = min_load,
421                "Selected least loaded target"
422            );
423        } else {
424            warn!("No healthy targets found for least loaded selection");
425        }
426
427        best_target
428    }
429
430    /// Extract hash key from request context
431    pub fn extract_hash_key(&self, context: &RequestContext) -> Option<String> {
432        let key = match &self.config.hash_key_extractor {
433            HashKeyExtractor::ClientIp => context.client_ip.map(|ip| ip.to_string()),
434            HashKeyExtractor::Header(name) => context.headers.get(name).cloned(),
435            HashKeyExtractor::Cookie(name) => {
436                // Parse cookie header and extract specific cookie
437                context.headers.get("cookie").and_then(|cookies| {
438                    cookies.split(';').find_map(|cookie| {
439                        let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
440                        if parts.len() == 2 && parts[0] == name {
441                            Some(parts[1].to_string())
442                        } else {
443                            None
444                        }
445                    })
446                })
447            }
448            HashKeyExtractor::Custom(extractor) => extractor(context),
449        };
450
451        trace!(
452            extractor = ?self.config.hash_key_extractor,
453            key_found = key.is_some(),
454            "Extracted hash key from request"
455        );
456
457        key
458    }
459
460    /// Track connection acquisition
461    pub fn acquire_connection(&self, target_index: usize) {
462        let count = self.connection_counts[target_index].fetch_add(1, Ordering::Relaxed) + 1;
463        let total = self.total_connections.fetch_add(1, Ordering::Relaxed) + 1;
464        trace!(
465            target_index = target_index,
466            target_connections = count,
467            total_connections = total,
468            "Acquired connection"
469        );
470    }
471
472    /// Track connection release
473    pub fn release_connection(&self, target_index: usize) {
474        let count = self.connection_counts[target_index].fetch_sub(1, Ordering::Relaxed) - 1;
475        let total = self.total_connections.fetch_sub(1, Ordering::Relaxed) - 1;
476        trace!(
477            target_index = target_index,
478            target_connections = count,
479            total_connections = total,
480            "Released connection"
481        );
482    }
483}
484
485#[async_trait]
486impl LoadBalancer for ConsistentHashBalancer {
    /// Select a target for the request.
    ///
    /// The hash key comes from the configured extractor; when no key can
    /// be extracted (or `context` is `None`) a random key is generated,
    /// which spreads such requests across the ring but provides no
    /// session affinity for them.
    ///
    /// When bounded loads is enabled, a connection is acquired here and
    /// the selection's metadata carries `target_index` so `release` can
    /// decrement the same counter. If the caller never invokes `release`
    /// for a selection, the load counters drift upward.
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            has_context = context.is_some(),
            "Consistent hash select called"
        );

        // Extract hash key from context or use random fallback
        let (hash_key, used_random) = context
            .and_then(|ctx| self.extract_hash_key(ctx))
            .map(|k| (k, false))
            .unwrap_or_else(|| {
                // Generate random key for requests without proper hash key
                use rand::Rng;
                let mut rng = rand::thread_rng();
                let key = format!("random-{}", rng.gen::<u64>());
                trace!(random_key = %key, "Generated random hash key (no context key)");
                (key, true)
            });

        let target_index = self.find_target(&hash_key).await.ok_or_else(|| {
            warn!("No healthy upstream targets available");
            SentinelError::NoHealthyUpstream
        })?;

        let target = &self.targets[target_index];

        // Track connection for bounded loads
        if self.config.bounded_loads {
            self.acquire_connection(target_index);
        }

        // Read after the acquire above, so with bounded loads this count
        // includes the current request.
        let current_load = self.connection_counts[target_index].load(Ordering::Relaxed);

        debug!(
            target = %format!("{}:{}", target.address, target.port),
            hash_key = %hash_key,
            target_index = target_index,
            current_load = current_load,
            used_random_key = used_random,
            "Consistent hash selected target"
        );

        Ok(TargetSelection {
            address: format!("{}:{}", target.address, target.port),
            weight: target.weight,
            // Metadata contract: `release` parses "target_index" back out
            // of this map to balance the bounded-loads counters.
            metadata: {
                let mut meta = HashMap::new();
                meta.insert("hash_key".to_string(), hash_key);
                meta.insert("target_index".to_string(), target_index.to_string());
                meta.insert("load".to_string(), current_load.to_string());
                meta.insert("algorithm".to_string(), "consistent_hash".to_string());
                meta
            },
        })
    }
542
    /// Record a health report for `address` (the `"address:port"` target
    /// id) and rebuild the ring when the status actually changed.
    ///
    /// NOTE(review): a first-ever report of `healthy = true` (previous
    /// status `None`) also triggers a rebuild, even though unknown
    /// targets are already treated as healthy — harmless, but one
    /// redundant rebuild per target.
    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            address = %address,
            healthy = healthy,
            "Reporting target health"
        );

        let mut health = self.health_status.write().await;
        let previous = health.insert(address.to_string(), healthy);

        // Rebuild ring if health status changed
        if previous != Some(healthy) {
            info!(
                address = %address,
                previous_status = ?previous,
                new_status = healthy,
                "Target health changed, rebuilding ring"
            );
            // Must drop the write guard first: `rebuild_ring` takes a
            // read lock on the same `health_status` RwLock, which would
            // deadlock against the write guard held by this task.
            drop(health); // Release lock before rebuild
            self.rebuild_ring().await;
        }
    }
565
566    async fn healthy_targets(&self) -> Vec<String> {
567        let health = self.health_status.read().await;
568        let targets: Vec<String> = self
569            .targets
570            .iter()
571            .filter_map(|t| {
572                let target_id = format!("{}:{}", t.address, t.port);
573                if health.get(&target_id).copied().unwrap_or(true) {
574                    Some(target_id)
575                } else {
576                    None
577                }
578            })
579            .collect();
580
581        trace!(
582            total_targets = self.targets.len(),
583            healthy_count = targets.len(),
584            "Retrieved healthy targets"
585        );
586
587        targets
588    }
589
590    /// Release connection when request completes
591    async fn release(&self, selection: &TargetSelection) {
592        if self.config.bounded_loads {
593            if let Some(index_str) = selection.metadata.get("target_index") {
594                if let Ok(index) = index_str.parse::<usize>() {
595                    trace!(
596                        target_index = index,
597                        address = %selection.address,
598                        "Releasing connection for bounded loads"
599                    );
600                    self.release_connection(index);
601                }
602            }
603        }
604    }
605}
606
607#[cfg(test)]
608mod tests {
609    use super::*;
610
611    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
612        (0..count)
613            .map(|i| UpstreamTarget {
614                address: format!("10.0.0.{}", i + 1),
615                port: 8080,
616                weight: 100,
617            })
618            .collect()
619    }
620
    /// 10_000 keys over 5 targets should each land within ±50% of the
    /// mean — a loose smoke check of ring uniformity, not a statistical
    /// proof. Bounded loads is disabled so only ring placement matters.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_consistent_distribution() {
        let targets = create_test_targets(5);
        let config = ConsistentHashConfig {
            virtual_nodes: 100,
            bounded_loads: false,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Test distribution of 10000 keys
        let mut distribution = vec![0u64; targets.len()];

        for i in 0..10000 {
            // Only 256 distinct client IPs (i % 256), so keys repeat and
            // exercise the lookup cache as well as the ring.
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:1234", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: "/".to_string(),
                method: "GET".to_string(),
            };

            if let Ok(selection) = balancer.select(Some(&context)).await {
                // Target index is recovered from the selection metadata
                // written by `select`.
                if let Some(index_str) = selection.metadata.get("target_index") {
                    if let Ok(index) = index_str.parse::<usize>() {
                        distribution[index] += 1;
                    }
                }
            }
        }

        // Check that distribution is relatively even (within 50% of average)
        let avg = 10000.0 / targets.len() as f64;
        for count in distribution {
            let ratio = count as f64 / avg;
            assert!(
                ratio > 0.5 && ratio < 1.5,
                "Distribution too skewed: {}",
                ratio
            );
        }
    }
663
    /// With target 0 at load 100 and a total of 110 across 3 targets,
    /// the average load is ~36.7 and the bound is 36.7 * 1.2 ≈ 44, so
    /// target 0 is over the bound. Whichever ring position the key maps
    /// to, the selection must avoid target 0: either the key's own
    /// target is under the bound, or the least-loaded fallback kicks in.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_bounded_loads() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig {
            virtual_nodes: 50,
            bounded_loads: true,
            max_load_factor: 1.2,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Simulate high load on first target
        balancer.connection_counts[0].store(100, Ordering::Relaxed);
        balancer.total_connections.store(110, Ordering::Relaxed);

        // New request should avoid overloaded target
        let context = RequestContext {
            client_ip: Some("192.168.1.1:1234".parse().unwrap()),
            headers: HashMap::new(),
            path: "/".to_string(),
            method: "GET".to_string(),
        };

        let selection = balancer.select(Some(&context)).await.unwrap();
        let index = selection
            .metadata
            .get("target_index")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap();

        // Should not select the overloaded target (index 0)
        assert_ne!(index, 0);
    }
698
699    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
700    async fn test_ring_rebuild_on_health_change() {
701        let targets = create_test_targets(3);
702        let config = ConsistentHashConfig::default();
703
704        let balancer = ConsistentHashBalancer::new(targets.clone(), config);
705
706        let initial_generation = balancer.generation.load(Ordering::SeqCst);
707
708        // Mark a target as unhealthy
709        balancer.report_health("10.0.0.1:8080", false).await;
710
711        // Generation should have incremented
712        let new_generation = balancer.generation.load(Ordering::SeqCst);
713        assert_eq!(new_generation, initial_generation + 1);
714
715        // Unhealthy target should not be selected
716        let healthy = balancer.healthy_targets().await;
717        assert_eq!(healthy.len(), 2);
718        assert!(!healthy.contains(&"10.0.0.1:8080".to_string()));
719    }
720}