//! sentinel_proxy/upstream/consistent_hash.rs — consistent-hash load balancer.

1use murmur3::murmur3_32;
2use std::collections::{BTreeMap, HashMap};
3use std::hash::{Hash, Hasher};
4use std::io::Cursor;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use xxhash_rust::xxh3::Xxh3;
9
10use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
11use sentinel_common::errors::{SentinelError, SentinelResult};
12use async_trait::async_trait;
13use tracing::{debug, info, trace, warn};
14
/// Hash function types supported by the consistent hash balancer.
#[derive(Debug, Clone, Copy)]
pub enum HashFunction {
    /// 64-bit XXH3 (via `xxhash_rust`); the default in `ConsistentHashConfig`.
    Xxh3,
    /// 32-bit Murmur3 widened to `u64` (only the low 32 bits carry entropy).
    Murmur3,
    /// `std::collections::hash_map::DefaultHasher` (SipHash).
    /// NOTE(review): DefaultHasher is not guaranteed stable across Rust
    /// releases, so ring layouts may shift after a toolchain upgrade.
    DefaultHasher,
}
22
/// Configuration for consistent hashing.
#[derive(Debug, Clone)]
pub struct ConsistentHashConfig {
    /// Number of virtual nodes per real target. More vnodes smooth the
    /// distribution at the cost of a larger ring.
    pub virtual_nodes: usize,
    /// Hash function used both for ring placement and request keys.
    pub hash_function: HashFunction,
    /// Enable bounded loads to prevent overload of a single target.
    pub bounded_loads: bool,
    /// Maximum load factor relative to the mean active-connection count
    /// (1.0 = average load, 1.25 = 25% above average).
    pub max_load_factor: f64,
    /// How the hash key is extracted from a request (e.g. headers, cookies).
    pub hash_key_extractor: HashKeyExtractor,
}
37
38impl Default for ConsistentHashConfig {
39    fn default() -> Self {
40        Self {
41            virtual_nodes: 150,
42            hash_function: HashFunction::Xxh3,
43            bounded_loads: true,
44            max_load_factor: 1.25,
45            hash_key_extractor: HashKeyExtractor::ClientIp,
46        }
47    }
48}
49
/// Defines how to extract the hash key from a request.
#[derive(Clone)]
pub enum HashKeyExtractor {
    /// Use the client's IP address (stringified) as the key.
    ClientIp,
    /// Use the value of the named request header.
    Header(String),
    /// Use the value of the named cookie from the `cookie` header.
    Cookie(String),
    /// User-supplied extraction function; returning `None` makes the
    /// balancer fall back to a random key.
    Custom(Arc<dyn Fn(&RequestContext) -> Option<String> + Send + Sync>),
}
58
59impl std::fmt::Debug for HashKeyExtractor {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Self::ClientIp => write!(f, "ClientIp"),
63            Self::Header(h) => write!(f, "Header({})", h),
64            Self::Cookie(c) => write!(f, "Cookie({})", c),
65            Self::Custom(_) => write!(f, "Custom"),
66        }
67    }
68}
69
/// Virtual node in the consistent hash ring.
#[derive(Debug, Clone)]
struct VirtualNode {
    /// Hash value of this virtual node. Duplicates the `BTreeMap` key it is
    /// stored under in the ring.
    hash: u64,
    /// Index of the real target (into `ConsistentHashBalancer::targets`)
    /// this virtual node represents.
    target_index: usize,
    /// Ordinal of this virtual node among the target's virtual nodes
    /// (0..virtual_nodes).
    virtual_index: usize,
}
80
/// Consistent hash load balancer with virtual nodes and bounded loads.
///
/// Targets are fixed at construction; only their health status changes at
/// runtime (which triggers a ring rebuild).
pub struct ConsistentHashBalancer {
    /// Configuration (immutable after construction).
    config: ConsistentHashConfig,
    /// All upstream targets. Never mutated, so indices stored in the ring
    /// and the lookup cache stay valid for the balancer's lifetime.
    targets: Vec<UpstreamTarget>,
    /// Hash ring: vnode hash -> virtual node, sorted by hash value.
    ring: Arc<RwLock<BTreeMap<u64, VirtualNode>>>,
    /// Target health status keyed by "address:port"; absent = healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Active connection count per target (for bounded loads), indexed in
    /// parallel with `targets`.
    connection_counts: Vec<Arc<AtomicU64>>,
    /// Total active connections across all targets.
    total_connections: Arc<AtomicU64>,
    /// Cache for recent hash lookups (key hash -> target index); cleared on
    /// every ring rebuild.
    lookup_cache: Arc<RwLock<HashMap<u64, usize>>>,
    /// Generation counter, incremented on each ring rebuild so observers can
    /// detect ring changes.
    generation: Arc<AtomicUsize>,
}
100
101impl ConsistentHashBalancer {
102    pub fn new(targets: Vec<UpstreamTarget>, config: ConsistentHashConfig) -> Self {
103        trace!(
104            target_count = targets.len(),
105            virtual_nodes = config.virtual_nodes,
106            hash_function = ?config.hash_function,
107            bounded_loads = config.bounded_loads,
108            max_load_factor = config.max_load_factor,
109            hash_key_extractor = ?config.hash_key_extractor,
110            "Creating consistent hash balancer"
111        );
112
113        let connection_counts = targets
114            .iter()
115            .map(|_| Arc::new(AtomicU64::new(0)))
116            .collect();
117
118        let balancer = Self {
119            config,
120            targets: targets.clone(),
121            ring: Arc::new(RwLock::new(BTreeMap::new())),
122            health_status: Arc::new(RwLock::new(HashMap::new())),
123            connection_counts,
124            total_connections: Arc::new(AtomicU64::new(0)),
125            lookup_cache: Arc::new(RwLock::new(HashMap::with_capacity(1000))),
126            generation: Arc::new(AtomicUsize::new(0)),
127        };
128
129        // Build initial ring
130        tokio::task::block_in_place(|| {
131            tokio::runtime::Handle::current().block_on(balancer.rebuild_ring());
132        });
133
134        debug!(
135            target_count = targets.len(),
136            "Consistent hash balancer initialized"
137        );
138
139        balancer
140    }
141
142    /// Rebuild the hash ring based on current targets and health
143    async fn rebuild_ring(&self) {
144        trace!(
145            total_targets = self.targets.len(),
146            virtual_nodes_per_target = self.config.virtual_nodes,
147            "Starting hash ring rebuild"
148        );
149
150        let mut new_ring = BTreeMap::new();
151        let health = self.health_status.read().await;
152
153        for (index, target) in self.targets.iter().enumerate() {
154            let target_id = format!("{}:{}", target.address, target.port);
155            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
156
157            if !is_healthy {
158                trace!(
159                    target_id = %target_id,
160                    target_index = index,
161                    "Skipping unhealthy target in ring rebuild"
162                );
163                continue;
164            }
165
166            // Add virtual nodes for this target
167            for vnode in 0..self.config.virtual_nodes {
168                let vnode_key = format!("{}-vnode-{}", target_id, vnode);
169                let hash = self.hash_key(&vnode_key);
170
171                new_ring.insert(
172                    hash,
173                    VirtualNode {
174                        hash,
175                        target_index: index,
176                        virtual_index: vnode,
177                    },
178                );
179            }
180
181            trace!(
182                target_id = %target_id,
183                target_index = index,
184                vnodes_added = self.config.virtual_nodes,
185                "Added virtual nodes for target"
186            );
187        }
188
189        let healthy_count = new_ring
190            .values()
191            .map(|n| n.target_index)
192            .collect::<std::collections::HashSet<_>>()
193            .len();
194
195        if new_ring.is_empty() {
196            warn!("No healthy targets available for consistent hash ring");
197        } else {
198            info!(
199                virtual_nodes = new_ring.len(),
200                healthy_targets = healthy_count,
201                "Rebuilt consistent hash ring"
202            );
203        }
204
205        *self.ring.write().await = new_ring;
206
207        // Clear cache on ring change
208        let cache_size = self.lookup_cache.read().await.len();
209        self.lookup_cache.write().await.clear();
210        let new_generation = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
211
212        trace!(
213            cache_entries_cleared = cache_size,
214            new_generation = new_generation,
215            "Ring rebuild complete, cache cleared"
216        );
217    }
218
219    /// Hash a key using the configured hash function
220    fn hash_key(&self, key: &str) -> u64 {
221        match self.config.hash_function {
222            HashFunction::Xxh3 => {
223                let mut hasher = Xxh3::new();
224                hasher.update(key.as_bytes());
225                hasher.digest()
226            }
227            HashFunction::Murmur3 => {
228                let mut cursor = Cursor::new(key.as_bytes());
229                murmur3_32(&mut cursor, 0).unwrap_or(0) as u64
230            }
231            HashFunction::DefaultHasher => {
232                use std::collections::hash_map::DefaultHasher;
233                let mut hasher = DefaultHasher::new();
234                key.hash(&mut hasher);
235                hasher.finish()
236            }
237        }
238    }
239
240    /// Find target using consistent hashing with optional bounded loads
241    async fn find_target(&self, hash_key: &str) -> Option<usize> {
242        let key_hash = self.hash_key(hash_key);
243
244        trace!(
245            hash_key = %hash_key,
246            key_hash = key_hash,
247            bounded_loads = self.config.bounded_loads,
248            "Finding target for hash key"
249        );
250
251        // Check cache first
252        {
253            let cache = self.lookup_cache.read().await;
254            if let Some(&target_index) = cache.get(&key_hash) {
255                // Verify target is still healthy
256                let health = self.health_status.read().await;
257                let target = &self.targets[target_index];
258                let target_id = format!("{}:{}", target.address, target.port);
259                if health.get(&target_id).copied().unwrap_or(true) {
260                    trace!(
261                        hash_key = %hash_key,
262                        target_index = target_index,
263                        "Cache hit for hash key"
264                    );
265                    return Some(target_index);
266                }
267                trace!(
268                    hash_key = %hash_key,
269                    target_index = target_index,
270                    "Cache hit but target unhealthy"
271                );
272            }
273        }
274
275        let ring = self.ring.read().await;
276
277        if ring.is_empty() {
278            warn!("Hash ring is empty, no targets available");
279            return None;
280        }
281
282        // Find the first virtual node with hash >= key_hash
283        let candidates = if let Some((&_node_hash, vnode)) = ring.range(key_hash..).next() {
284            vec![vnode.clone()]
285        } else {
286            // Wrap around to the first node
287            ring.iter()
288                .next()
289                .map(|(_, vnode)| vec![vnode.clone()])
290                .unwrap_or_default()
291        };
292
293        trace!(
294            hash_key = %hash_key,
295            candidate_count = candidates.len(),
296            "Found candidates on hash ring"
297        );
298
299        // If bounded loads is disabled, return the first candidate
300        if !self.config.bounded_loads {
301            let target_index = candidates.first().map(|n| n.target_index);
302
303            // Update cache
304            if let Some(idx) = target_index {
305                self.lookup_cache.write().await.insert(key_hash, idx);
306                trace!(
307                    hash_key = %hash_key,
308                    target_index = idx,
309                    "Selected target (no bounded loads)"
310                );
311            }
312
313            return target_index;
314        }
315
316        // Bounded loads: check if target is overloaded
317        let avg_load = self.calculate_average_load().await;
318        let max_load = (avg_load * self.config.max_load_factor) as u64;
319
320        trace!(
321            avg_load = avg_load,
322            max_load = max_load,
323            max_load_factor = self.config.max_load_factor,
324            "Checking bounded loads"
325        );
326
327        // Try candidates in order until we find one that's not overloaded
328        for vnode in candidates {
329            let current_load = self.connection_counts[vnode.target_index].load(Ordering::Relaxed);
330
331            trace!(
332                target_index = vnode.target_index,
333                current_load = current_load,
334                max_load = max_load,
335                "Evaluating candidate load"
336            );
337
338            if current_load <= max_load {
339                // Update cache
340                self.lookup_cache
341                    .write()
342                    .await
343                    .insert(key_hash, vnode.target_index);
344                debug!(
345                    hash_key = %hash_key,
346                    target_index = vnode.target_index,
347                    current_load = current_load,
348                    "Selected target within load bounds"
349                );
350                return Some(vnode.target_index);
351            }
352        }
353
354        trace!(
355            hash_key = %hash_key,
356            "All candidates overloaded, falling back to least loaded"
357        );
358
359        // If all candidates are overloaded, find least loaded target
360        self.find_least_loaded_target().await
361    }
362
363    /// Calculate average load across all healthy targets
364    async fn calculate_average_load(&self) -> f64 {
365        let health = self.health_status.read().await;
366        let healthy_count = self
367            .targets
368            .iter()
369            .filter(|t| {
370                let target_id = format!("{}:{}", t.address, t.port);
371                health.get(&target_id).copied().unwrap_or(true)
372            })
373            .count();
374
375        if healthy_count == 0 {
376            return 0.0;
377        }
378
379        let total = self.total_connections.load(Ordering::Relaxed);
380        total as f64 / healthy_count as f64
381    }
382
383    /// Find the least loaded target when all consistent hash candidates are overloaded
384    async fn find_least_loaded_target(&self) -> Option<usize> {
385        trace!("Finding least loaded target as fallback");
386
387        let health = self.health_status.read().await;
388
389        let mut min_load = u64::MAX;
390        let mut best_target = None;
391
392        for (index, target) in self.targets.iter().enumerate() {
393            let target_id = format!("{}:{}", target.address, target.port);
394            if !health.get(&target_id).copied().unwrap_or(true) {
395                trace!(
396                    target_index = index,
397                    target_id = %target_id,
398                    "Skipping unhealthy target"
399                );
400                continue;
401            }
402
403            let load = self.connection_counts[index].load(Ordering::Relaxed);
404            trace!(
405                target_index = index,
406                target_id = %target_id,
407                load = load,
408                "Evaluating target load"
409            );
410
411            if load < min_load {
412                min_load = load;
413                best_target = Some(index);
414            }
415        }
416
417        if let Some(idx) = best_target {
418            debug!(
419                target_index = idx,
420                load = min_load,
421                "Selected least loaded target"
422            );
423        } else {
424            warn!("No healthy targets found for least loaded selection");
425        }
426
427        best_target
428    }
429
430    /// Extract hash key from request context
431    pub fn extract_hash_key(&self, context: &RequestContext) -> Option<String> {
432        let key = match &self.config.hash_key_extractor {
433            HashKeyExtractor::ClientIp => context.client_ip.map(|ip| ip.to_string()),
434            HashKeyExtractor::Header(name) => context.headers.get(name).cloned(),
435            HashKeyExtractor::Cookie(name) => {
436                // Parse cookie header and extract specific cookie
437                context.headers.get("cookie").and_then(|cookies| {
438                    cookies.split(';').find_map(|cookie| {
439                        let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
440                        if parts.len() == 2 && parts[0] == name {
441                            Some(parts[1].to_string())
442                        } else {
443                            None
444                        }
445                    })
446                })
447            }
448            HashKeyExtractor::Custom(extractor) => extractor(context),
449        };
450
451        trace!(
452            extractor = ?self.config.hash_key_extractor,
453            key_found = key.is_some(),
454            "Extracted hash key from request"
455        );
456
457        key
458    }
459
460    /// Track connection acquisition
461    pub fn acquire_connection(&self, target_index: usize) {
462        let count = self.connection_counts[target_index].fetch_add(1, Ordering::Relaxed) + 1;
463        let total = self.total_connections.fetch_add(1, Ordering::Relaxed) + 1;
464        trace!(
465            target_index = target_index,
466            target_connections = count,
467            total_connections = total,
468            "Acquired connection"
469        );
470    }
471
472    /// Track connection release
473    pub fn release_connection(&self, target_index: usize) {
474        let count = self.connection_counts[target_index].fetch_sub(1, Ordering::Relaxed) - 1;
475        let total = self.total_connections.fetch_sub(1, Ordering::Relaxed) - 1;
476        trace!(
477            target_index = target_index,
478            target_connections = count,
479            total_connections = total,
480            "Released connection"
481        );
482    }
483}
484
485#[async_trait]
486impl LoadBalancer for ConsistentHashBalancer {
487    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
488        trace!(
489            has_context = context.is_some(),
490            "Consistent hash select called"
491        );
492
493        // Extract hash key from context or use random fallback
494        let (hash_key, used_random) = context
495            .and_then(|ctx| self.extract_hash_key(ctx))
496            .map(|k| (k, false))
497            .unwrap_or_else(|| {
498                // Generate random key for requests without proper hash key
499                use rand::Rng;
500                let mut rng = rand::thread_rng();
501                let key = format!("random-{}", rng.gen::<u64>());
502                trace!(random_key = %key, "Generated random hash key (no context key)");
503                (key, true)
504            });
505
506        let target_index = self
507            .find_target(&hash_key)
508            .await
509            .ok_or_else(|| {
510                warn!("No healthy upstream targets available");
511                SentinelError::NoHealthyUpstream
512            })?;
513
514        let target = &self.targets[target_index];
515
516        // Track connection for bounded loads
517        if self.config.bounded_loads {
518            self.acquire_connection(target_index);
519        }
520
521        let current_load = self.connection_counts[target_index].load(Ordering::Relaxed);
522
523        debug!(
524            target = %format!("{}:{}", target.address, target.port),
525            hash_key = %hash_key,
526            target_index = target_index,
527            current_load = current_load,
528            used_random_key = used_random,
529            "Consistent hash selected target"
530        );
531
532        Ok(TargetSelection {
533            address: format!("{}:{}", target.address, target.port),
534            weight: target.weight,
535            metadata: {
536                let mut meta = HashMap::new();
537                meta.insert("hash_key".to_string(), hash_key);
538                meta.insert("target_index".to_string(), target_index.to_string());
539                meta.insert("load".to_string(), current_load.to_string());
540                meta.insert("algorithm".to_string(), "consistent_hash".to_string());
541                meta
542            },
543        })
544    }
545
546    async fn report_health(&self, address: &str, healthy: bool) {
547        trace!(
548            address = %address,
549            healthy = healthy,
550            "Reporting target health"
551        );
552
553        let mut health = self.health_status.write().await;
554        let previous = health.insert(address.to_string(), healthy);
555
556        // Rebuild ring if health status changed
557        if previous != Some(healthy) {
558            info!(
559                address = %address,
560                previous_status = ?previous,
561                new_status = healthy,
562                "Target health changed, rebuilding ring"
563            );
564            drop(health); // Release lock before rebuild
565            self.rebuild_ring().await;
566        }
567    }
568
569    async fn healthy_targets(&self) -> Vec<String> {
570        let health = self.health_status.read().await;
571        let targets: Vec<String> = self.targets
572            .iter()
573            .filter_map(|t| {
574                let target_id = format!("{}:{}", t.address, t.port);
575                if health.get(&target_id).copied().unwrap_or(true) {
576                    Some(target_id)
577                } else {
578                    None
579                }
580            })
581            .collect();
582
583        trace!(
584            total_targets = self.targets.len(),
585            healthy_count = targets.len(),
586            "Retrieved healthy targets"
587        );
588
589        targets
590    }
591
592    /// Release connection when request completes
593    async fn release(&self, selection: &TargetSelection) {
594        if self.config.bounded_loads {
595            if let Some(index_str) = selection.metadata.get("target_index") {
596                if let Ok(index) = index_str.parse::<usize>() {
597                    trace!(
598                        target_index = index,
599                        address = %selection.address,
600                        "Releasing connection for bounded loads"
601                    );
602                    self.release_connection(index);
603                }
604            }
605        }
606    }
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` targets on sequential 10.0.0.x addresses, port 8080.
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    // NOTE: the multi_thread flavor is required. A plain #[tokio::test] runs
    // on a current-thread runtime, where `ConsistentHashBalancer::new`'s use
    // of `tokio::task::block_in_place` panics before any assertion runs.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_consistent_distribution() {
        let targets = create_test_targets(5);
        let config = ConsistentHashConfig {
            virtual_nodes: 100,
            bounded_loads: false,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Test distribution of 10000 keys across 256 distinct client IPs.
        let mut distribution = vec![0u64; targets.len()];

        for i in 0..10000 {
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:1234", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: "/".to_string(),
                method: "GET".to_string(),
            };

            if let Ok(selection) = balancer.select(Some(&context)).await {
                if let Some(index_str) = selection.metadata.get("target_index") {
                    if let Ok(index) = index_str.parse::<usize>() {
                        distribution[index] += 1;
                    }
                }
            }
        }

        // Check that distribution is relatively even (within 50% of average).
        let avg = 10000.0 / targets.len() as f64;
        for count in distribution {
            let ratio = count as f64 / avg;
            assert!(
                ratio > 0.5 && ratio < 1.5,
                "Distribution too skewed: {}",
                ratio
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_bounded_loads() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig {
            virtual_nodes: 50,
            bounded_loads: true,
            max_load_factor: 1.2,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Simulate high load on the first target: 100 of 110 connections,
        // far above max_load = (110/3) * 1.2.
        balancer.connection_counts[0].store(100, Ordering::Relaxed);
        balancer.total_connections.store(110, Ordering::Relaxed);

        // New request should avoid the overloaded target.
        let context = RequestContext {
            client_ip: Some("192.168.1.1:1234".parse().unwrap()),
            headers: HashMap::new(),
            path: "/".to_string(),
            method: "GET".to_string(),
        };

        let selection = balancer.select(Some(&context)).await.unwrap();
        let index = selection
            .metadata
            .get("target_index")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap();

        // Should not select the overloaded target (index 0).
        assert_ne!(index, 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_ring_rebuild_on_health_change() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig::default();

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let initial_generation = balancer.generation.load(Ordering::SeqCst);

        // Mark a target as unhealthy; the status flip must trigger a rebuild.
        balancer.report_health("10.0.0.1:8080", false).await;

        // Generation should have incremented exactly once.
        let new_generation = balancer.generation.load(Ordering::SeqCst);
        assert_eq!(new_generation, initial_generation + 1);

        // The unhealthy target must no longer be reported as healthy.
        let healthy = balancer.healthy_targets().await;
        assert_eq!(healthy.len(), 2);
        assert!(!healthy.contains(&"10.0.0.1:8080".to_string()));
    }
}