sentinel_proxy/upstream/consistent_hash.rs

use murmur3::murmur3_32;
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::io::Cursor;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use xxhash_rust::xxh3::Xxh3;

use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
use sentinel_common::errors::{SentinelError, SentinelResult};
use async_trait::async_trait;
use tracing::{debug, info, warn};
/// Hash function types supported by the consistent hash balancer
#[derive(Debug, Clone, Copy)]
pub enum HashFunction {
    Xxh3,
    Murmur3,
    DefaultHasher,
}

/// Configuration for consistent hashing
#[derive(Debug, Clone)]
pub struct ConsistentHashConfig {
    /// Number of virtual nodes per real target
    pub virtual_nodes: usize,
    /// Hash function to use
    pub hash_function: HashFunction,
    /// Enable bounded loads to prevent overload
    pub bounded_loads: bool,
    /// Maximum load factor (1.0 = average load, 1.25 = 25% above average)
    pub max_load_factor: f64,
    /// Key extraction function (e.g., from headers, cookies)
    pub hash_key_extractor: HashKeyExtractor,
}
37
38impl Default for ConsistentHashConfig {
39    fn default() -> Self {
40        Self {
41            virtual_nodes: 150,
42            hash_function: HashFunction::Xxh3,
43            bounded_loads: true,
44            max_load_factor: 1.25,
45            hash_key_extractor: HashKeyExtractor::ClientIp,
46        }
47    }
48}
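
// A minimal construction sketch: cookie-based session affinity on top of the
// defaults. The cookie name here is illustrative, not part of this module's API.
//
//     let config = ConsistentHashConfig {
//         hash_key_extractor: HashKeyExtractor::Cookie("session_id".to_string()),
//         ..Default::default()
//     };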

/// Defines how to extract the hash key from a request
#[derive(Clone)]
pub enum HashKeyExtractor {
    ClientIp,
    Header(String),
    Cookie(String),
    Custom(Arc<dyn Fn(&RequestContext) -> Option<String> + Send + Sync>),
}

impl std::fmt::Debug for HashKeyExtractor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ClientIp => write!(f, "ClientIp"),
            Self::Header(h) => write!(f, "Header({})", h),
            Self::Cookie(c) => write!(f, "Cookie({})", c),
            Self::Custom(_) => write!(f, "Custom"),
        }
    }
}
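
// Sketch of a custom extractor; keying on a tenant header is an assumed
// deployment detail, not something this module prescribes:
//
//     let extractor = HashKeyExtractor::Custom(Arc::new(|ctx: &RequestContext| {
//         ctx.headers.get("x-tenant-id").cloned()
//     }));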

/// Virtual node in the consistent hash ring
#[derive(Debug, Clone)]
struct VirtualNode {
    /// Hash value of this virtual node
    hash: u64,
    /// Index of the real target this virtual node represents
    target_index: usize,
    /// Virtual node number for this target
    virtual_index: usize,
}

/// Consistent hash load balancer with virtual nodes and bounded loads
pub struct ConsistentHashBalancer {
    /// Configuration
    config: ConsistentHashConfig,
    /// All upstream targets
    targets: Vec<UpstreamTarget>,
    /// Hash ring (sorted by hash value)
    ring: Arc<RwLock<BTreeMap<u64, VirtualNode>>>,
    /// Target health status
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Active connection count per target (for bounded loads)
    connection_counts: Vec<Arc<AtomicU64>>,
    /// Total active connections
    total_connections: Arc<AtomicU64>,
    /// Cache for recent hash lookups (hash -> target_index)
    lookup_cache: Arc<RwLock<HashMap<u64, usize>>>,
    /// Generation counter for detecting ring changes
    generation: Arc<AtomicUsize>,
}

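// Concurrency layout: the BTreeMap ring gives the ordered successor lookup
// (range(key..)) that consistent hashing needs, while per-target connection
// counts are plain atomics so request accounting never takes a lock.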
impl ConsistentHashBalancer {
    pub fn new(targets: Vec<UpstreamTarget>, config: ConsistentHashConfig) -> Self {
        let connection_counts = targets
            .iter()
            .map(|_| Arc::new(AtomicU64::new(0)))
            .collect();

        let balancer = Self {
            config,
            targets,
            ring: Arc::new(RwLock::new(BTreeMap::new())),
            health_status: Arc::new(RwLock::new(HashMap::new())),
            connection_counts,
            total_connections: Arc::new(AtomicU64::new(0)),
            lookup_cache: Arc::new(RwLock::new(HashMap::with_capacity(1000))),
            generation: Arc::new(AtomicUsize::new(0)),
        };

        // Build the initial ring synchronously. No health reports exist yet,
        // so all targets count as healthy, and the balancer has not been
        // shared, so the uncontended try_write cannot fail. This replaces a
        // block_in_place + block_on dance, which panics on current-thread
        // runtimes (including the default #[tokio::test] runtime).
        let initial_ring = balancer.build_ring(&HashMap::new());
        *balancer
            .ring
            .try_write()
            .expect("ring lock is uncontended during construction") = initial_ring;
        balancer.generation.fetch_add(1, Ordering::SeqCst);

        balancer
    }

    /// Build a hash ring from the given health snapshot; unhealthy targets
    /// contribute no virtual nodes
    fn build_ring(&self, health: &HashMap<String, bool>) -> BTreeMap<u64, VirtualNode> {
        let mut new_ring = BTreeMap::new();

        for (index, target) in self.targets.iter().enumerate() {
            let target_id = format!("{}:{}", target.address, target.port);
            let is_healthy = health.get(&target_id).copied().unwrap_or(true);

            if !is_healthy {
                continue;
            }

            // Add virtual nodes for this target
            for vnode in 0..self.config.virtual_nodes {
                let vnode_key = format!("{}-vnode-{}", target_id, vnode);
                let hash = self.hash_key(&vnode_key);

                new_ring.insert(
                    hash,
                    VirtualNode {
                        hash,
                        target_index: index,
                        virtual_index: vnode,
                    },
                );
            }
        }

        new_ring
    }

    /// Rebuild the hash ring based on current targets and health
    async fn rebuild_ring(&self) {
        // Snapshot health inside a block so the read guard is released
        // before the ring write lock is taken
        let new_ring = {
            let health = self.health_status.read().await;
            self.build_ring(&health)
        };

        if new_ring.is_empty() {
            warn!("No healthy targets available for consistent hash ring");
        } else {
            info!(
                "Rebuilt consistent hash ring with {} virtual nodes from {} healthy targets",
                new_ring.len(),
                new_ring
                    .values()
                    .map(|n| n.target_index)
                    .collect::<std::collections::HashSet<_>>()
                    .len()
            );
        }

        *self.ring.write().await = new_ring;

        // Invalidate cached lookups; stale entries may point at removed targets
        self.lookup_cache.write().await.clear();
        self.generation.fetch_add(1, Ordering::SeqCst);
    }
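
    // Sizing note: with the default 150 virtual nodes per target, five healthy
    // targets produce a 750-entry ring, and each lookup below is a single
    // O(log n) BTreeMap range scan.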

    /// Hash a key using the configured hash function
    fn hash_key(&self, key: &str) -> u64 {
        match self.config.hash_function {
            HashFunction::Xxh3 => {
                let mut hasher = Xxh3::new();
                hasher.update(key.as_bytes());
                hasher.digest()
            }
            HashFunction::Murmur3 => {
                // Reading from an in-memory cursor cannot fail, so the
                // unwrap_or(0) fallback is unreachable in practice; the
                // 32-bit digest is widened to u64 with zeroed high bits
                let mut cursor = Cursor::new(key.as_bytes());
                murmur3_32(&mut cursor, 0).unwrap_or(0) as u64
            }
            HashFunction::DefaultHasher => {
                use std::collections::hash_map::DefaultHasher;
                let mut hasher = DefaultHasher::new();
                key.hash(&mut hasher);
                hasher.finish()
            }
        }
    }
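
    // Portability note: std documents DefaultHasher's algorithm as unspecified
    // and subject to change across releases, so rings built by binaries from
    // different toolchains may disagree. Prefer Xxh3 or Murmur3 whenever
    // multiple proxy instances must route identically.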

    /// Find target using consistent hashing with optional bounded loads
    async fn find_target(&self, hash_key: &str) -> Option<usize> {
        let key_hash = self.hash_key(hash_key);

        // Check cache first
        {
            let cache = self.lookup_cache.read().await;
            if let Some(&target_index) = cache.get(&key_hash) {
                // Verify the cached target is still healthy
                let health = self.health_status.read().await;
                let target = &self.targets[target_index];
                let target_id = format!("{}:{}", target.address, target.port);
                if health.get(&target_id).copied().unwrap_or(true) {
                    return Some(target_index);
                }
            }
        }

        let ring = self.ring.read().await;

        if ring.is_empty() {
            return None;
        }

        // Walk the ring clockwise from key_hash, wrapping around, and collect
        // each distinct target once in ring order. The first entry is the
        // classic consistent-hash owner; the rest serve as ordered fallbacks
        // when bounded loads rejects an overloaded candidate.
        let mut candidates: Vec<VirtualNode> = Vec::new();
        let mut seen = std::collections::HashSet::new();
        for vnode in ring
            .range(key_hash..)
            .chain(ring.range(..key_hash))
            .map(|(_, vnode)| vnode)
        {
            if seen.insert(vnode.target_index) {
                candidates.push(vnode.clone());
                // Without bounded loads only the owner is needed
                if !self.config.bounded_loads || seen.len() == self.targets.len() {
                    break;
                }
            }
        }
        // Release the ring lock before awaiting on other locks below
        drop(ring);

        // If bounded loads is disabled, return the hash owner directly
        if !self.config.bounded_loads {
            let target_index = candidates.first().map(|n| n.target_index);

            // Update cache
            if let Some(idx) = target_index {
                self.lookup_cache.write().await.insert(key_hash, idx);
            }

            return target_index;
        }

        // Bounded loads: cap each target at max_load_factor times the average
        let avg_load = self.calculate_average_load().await;
        let max_load = (avg_load * self.config.max_load_factor) as u64;

        // Try candidates in ring order until one is under the load bound
        for vnode in candidates {
            let current_load = self.connection_counts[vnode.target_index].load(Ordering::Relaxed);

            if current_load <= max_load {
                // Update cache
                self.lookup_cache
                    .write()
                    .await
                    .insert(key_hash, vnode.target_index);
                return Some(vnode.target_index);
            }
        }

        // All candidates are over the bound; fall back to the least loaded target
        self.find_least_loaded_target().await
    }
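
    // Worked example of the bound: with 3 healthy targets and 110 active
    // connections, the average load is ~36.7; at the default max_load_factor
    // of 1.25 the cap is 45, so a target already holding 100 connections is
    // skipped in favor of the next distinct target on the ring.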

    /// Calculate average load across all healthy targets
    async fn calculate_average_load(&self) -> f64 {
        let health = self.health_status.read().await;
        let healthy_count = self
            .targets
            .iter()
            .filter(|t| {
                let target_id = format!("{}:{}", t.address, t.port);
                health.get(&target_id).copied().unwrap_or(true)
            })
            .count();

        if healthy_count == 0 {
            return 0.0;
        }

        let total = self.total_connections.load(Ordering::Relaxed);
        total as f64 / healthy_count as f64
    }

    /// Find the least loaded target when all consistent hash candidates are overloaded
    async fn find_least_loaded_target(&self) -> Option<usize> {
        let health = self.health_status.read().await;

        let mut min_load = u64::MAX;
        let mut best_target = None;

        for (index, target) in self.targets.iter().enumerate() {
            let target_id = format!("{}:{}", target.address, target.port);
            if !health.get(&target_id).copied().unwrap_or(true) {
                continue;
            }

            let load = self.connection_counts[index].load(Ordering::Relaxed);
            if load < min_load {
                min_load = load;
                best_target = Some(index);
            }
        }

        best_target
    }
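
    // Note: this fallback is not written to the lookup cache, so once load
    // drops back under the bound, later requests for the same key return to
    // their consistent-hash owner rather than sticking to the escape hatch.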

    /// Extract hash key from request context
    pub fn extract_hash_key(&self, context: &RequestContext) -> Option<String> {
        match &self.config.hash_key_extractor {
            HashKeyExtractor::ClientIp => context.client_ip.map(|ip| ip.to_string()),
            // Header names are matched verbatim; callers are expected to
            // normalize case before the lookup
            HashKeyExtractor::Header(name) => context.headers.get(name).cloned(),
            HashKeyExtractor::Cookie(name) => {
                // Parse the Cookie header and extract the named cookie
                context.headers.get("cookie").and_then(|cookies| {
                    cookies.split(';').find_map(|cookie| {
                        let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
                        if parts.len() == 2 && parts[0] == name {
                            Some(parts[1].to_string())
                        } else {
                            None
                        }
                    })
                })
            }
            HashKeyExtractor::Custom(extractor) => extractor(context),
        }
    }
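
    // For example, Cookie("sid") applied to a request carrying
    // "cookie: theme=dark; sid=abc123" yields Some("abc123"); a request
    // without that cookie yields None, and select() falls back to a random key.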

    /// Track connection acquisition
    pub fn acquire_connection(&self, target_index: usize) {
        self.connection_counts[target_index].fetch_add(1, Ordering::Relaxed);
        self.total_connections.fetch_add(1, Ordering::Relaxed);
    }

    /// Track connection release
    pub fn release_connection(&self, target_index: usize) {
        // Saturating decrement: a stray double release must not wrap the
        // counter to u64::MAX and permanently distort the load accounting
        let _ = self.connection_counts[target_index]
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
        let _ = self
            .total_connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
}

#[async_trait]
impl LoadBalancer for ConsistentHashBalancer {
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        // Extract hash key from context or use random fallback
        let hash_key = context
            .and_then(|ctx| self.extract_hash_key(ctx))
            .unwrap_or_else(|| {
                // Generate random key for requests without proper hash key
                use rand::Rng;
                let mut rng = rand::thread_rng();
                format!("random-{}", rng.gen::<u64>())
            });

        let target_index = self
            .find_target(&hash_key)
            .await
            .ok_or_else(|| SentinelError::NoHealthyUpstream)?;

        let target = &self.targets[target_index];

        // Track connection for bounded loads
        if self.config.bounded_loads {
            self.acquire_connection(target_index);
        }

        debug!(
            "Selected target {} for hash key {} (index: {})",
            format!("{}:{}", target.address, target.port),
            hash_key,
            target_index
        );

        Ok(TargetSelection {
            address: format!("{}:{}", target.address, target.port),
            weight: target.weight,
            metadata: {
                let mut meta = HashMap::new();
                meta.insert("hash_key".to_string(), hash_key);
                meta.insert("target_index".to_string(), target_index.to_string());
                meta.insert(
                    "load".to_string(),
                    self.connection_counts[target_index]
                        .load(Ordering::Relaxed)
                        .to_string(),
                );
                meta
            },
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        // Scope the write guard so it is dropped before rebuilding:
        // rebuild_ring re-acquires health_status for reading, and tokio's
        // RwLock is not reentrant, so holding the guard across that call
        // would deadlock
        let previous = {
            let mut health = self.health_status.write().await;
            health.insert(address.to_string(), healthy)
        };

        // Rebuild ring if health status changed
        if previous != Some(healthy) {
            info!(
                "Target {} health changed from {:?} to {}",
                address, previous, healthy
            );
            self.rebuild_ring().await;
        }
    }

    async fn healthy_targets(&self) -> Vec<String> {
        let health = self.health_status.read().await;
        self.targets
            .iter()
            .filter_map(|t| {
                let target_id = format!("{}:{}", t.address, t.port);
                if health.get(&target_id).copied().unwrap_or(true) {
                    Some(target_id)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Release connection when request completes
    async fn release(&self, selection: &TargetSelection) {
        if self.config.bounded_loads {
            if let Some(index_str) = selection.metadata.get("target_index") {
                if let Ok(index) = index_str.parse::<usize>() {
                    self.release_connection(index);
                }
            }
        }
    }
}
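
// End-to-end usage sketch (a hypothetical caller, assuming a tokio runtime
// and the RequestContext fields exercised in the tests below):
//
//     let balancer = ConsistentHashBalancer::new(targets, ConsistentHashConfig::default());
//     let selection = balancer.select(Some(&ctx)).await?;
//     // ... proxy the request to selection.address ...
//     balancer.release(&selection).await;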

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    #[tokio::test]
    async fn test_consistent_distribution() {
        let targets = create_test_targets(5);
        let config = ConsistentHashConfig {
            virtual_nodes: 100,
            bounded_loads: false,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Distribute 10,000 requests drawn from 256 distinct client IPs
        let mut distribution = vec![0u64; targets.len()];

        for i in 0..10000 {
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:1234", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: "/".to_string(),
                method: "GET".to_string(),
            };

            if let Ok(selection) = balancer.select(Some(&context)).await {
                if let Some(index_str) = selection.metadata.get("target_index") {
                    if let Ok(index) = index_str.parse::<usize>() {
                        distribution[index] += 1;
                    }
                }
            }
        }

        // Check that distribution is relatively even (within 50% of average)
        let avg = 10000.0 / targets.len() as f64;
        for count in distribution {
            let ratio = count as f64 / avg;
            assert!(
                ratio > 0.5 && ratio < 1.5,
                "Distribution too skewed: {}",
                ratio
            );
        }
    }

    #[tokio::test]
    async fn test_bounded_loads() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig {
            virtual_nodes: 50,
            bounded_loads: true,
            max_load_factor: 1.2,
            ..Default::default()
        };

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        // Simulate high load on the first target: avg = 110 / 3 ≈ 36.7, so
        // the bound is 36.7 * 1.2 = 44 and target 0 (at 100) is over it
        balancer.connection_counts[0].store(100, Ordering::Relaxed);
        balancer.total_connections.store(110, Ordering::Relaxed);

        // New request should avoid the overloaded target
        let context = RequestContext {
            client_ip: Some("192.168.1.1:1234".parse().unwrap()),
            headers: HashMap::new(),
            path: "/".to_string(),
            method: "GET".to_string(),
        };

        let selection = balancer.select(Some(&context)).await.unwrap();
        let index = selection
            .metadata
            .get("target_index")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap();

        // Should not select the overloaded target (index 0)
        assert_ne!(index, 0);
    }

    #[tokio::test]
    async fn test_ring_rebuild_on_health_change() {
        let targets = create_test_targets(3);
        let config = ConsistentHashConfig::default();

        let balancer = ConsistentHashBalancer::new(targets.clone(), config);

        let initial_generation = balancer.generation.load(Ordering::SeqCst);

        // Mark a target as unhealthy
        balancer.report_health("10.0.0.1:8080", false).await;

        // Generation should have incremented
        let new_generation = balancer.generation.load(Ordering::SeqCst);
        assert_eq!(new_generation, initial_generation + 1);

        // Unhealthy target should not be selected
        let healthy = balancer.healthy_targets().await;
        assert_eq!(healthy.len(), 2);
        assert!(!healthy.contains(&"10.0.0.1:8080".to_string()));
    }
}