// ruvector_cluster/shard.rs

1//! Sharding logic for distributed vector storage
2//!
3//! Implements consistent hashing for shard distribution and routing.
4
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::{BTreeMap, HashMap};
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use tracing::debug;
11
/// Number of virtual nodes placed on the ring per physical node.
/// More virtual nodes smooth the key distribution at the cost of a
/// larger ring (memory and slightly slower node add/remove).
const VIRTUAL_NODE_COUNT: usize = 150;
13
14/// Consistent hash ring for node assignment
15#[derive(Debug)]
16pub struct ConsistentHashRing {
17    /// Virtual nodes on the ring (hash -> node_id)
18    ring: BTreeMap<u64, String>,
19    /// Real nodes in the cluster
20    nodes: HashMap<String, usize>,
21    /// Replication factor
22    replication_factor: usize,
23}
24
25impl ConsistentHashRing {
26    /// Create a new consistent hash ring
27    pub fn new(replication_factor: usize) -> Self {
28        Self {
29            ring: BTreeMap::new(),
30            nodes: HashMap::new(),
31            replication_factor,
32        }
33    }
34
35    /// Add a node to the ring
36    pub fn add_node(&mut self, node_id: String) {
37        if self.nodes.contains_key(&node_id) {
38            return;
39        }
40
41        // Add virtual nodes for better distribution
42        for i in 0..VIRTUAL_NODE_COUNT {
43            let virtual_key = format!("{}:{}", node_id, i);
44            let hash = Self::hash_key(&virtual_key);
45            self.ring.insert(hash, node_id.clone());
46        }
47
48        self.nodes.insert(node_id, VIRTUAL_NODE_COUNT);
49        debug!(
50            "Added node to hash ring with {} virtual nodes",
51            VIRTUAL_NODE_COUNT
52        );
53    }
54
55    /// Remove a node from the ring
56    pub fn remove_node(&mut self, node_id: &str) {
57        if !self.nodes.contains_key(node_id) {
58            return;
59        }
60
61        // Remove all virtual nodes
62        self.ring.retain(|_, v| v != node_id);
63        self.nodes.remove(node_id);
64        debug!("Removed node from hash ring");
65    }
66
67    /// Get nodes responsible for a key
68    pub fn get_nodes(&self, key: &str, count: usize) -> Vec<String> {
69        if self.ring.is_empty() {
70            return Vec::new();
71        }
72
73        let hash = Self::hash_key(key);
74        let mut nodes = Vec::new();
75        let mut seen = std::collections::HashSet::new();
76
77        // Find the first node on or after the hash
78        for (_, node_id) in self.ring.range(hash..) {
79            if seen.insert(node_id.clone()) {
80                nodes.push(node_id.clone());
81                if nodes.len() >= count {
82                    return nodes;
83                }
84            }
85        }
86
87        // Wrap around to the beginning if needed
88        for (_, node_id) in self.ring.iter() {
89            if seen.insert(node_id.clone()) {
90                nodes.push(node_id.clone());
91                if nodes.len() >= count {
92                    return nodes;
93                }
94            }
95        }
96
97        nodes
98    }
99
100    /// Get the primary node for a key
101    pub fn get_primary_node(&self, key: &str) -> Option<String> {
102        self.get_nodes(key, 1).first().cloned()
103    }
104
105    /// Hash a key to a u64
106    fn hash_key(key: &str) -> u64 {
107        use std::collections::hash_map::DefaultHasher;
108        let mut hasher = DefaultHasher::new();
109        key.hash(&mut hasher);
110        hasher.finish()
111    }
112
113    /// Get the number of real nodes
114    pub fn node_count(&self) -> usize {
115        self.nodes.len()
116    }
117
118    /// List all real nodes
119    pub fn list_nodes(&self) -> Vec<String> {
120        self.nodes.keys().cloned().collect()
121    }
122}
123
124/// Routes queries to the correct shard
125pub struct ShardRouter {
126    /// Total number of shards
127    shard_count: u32,
128    /// Shard assignment cache
129    cache: Arc<RwLock<HashMap<String, u32>>>,
130}
131
132impl ShardRouter {
133    /// Create a new shard router
134    pub fn new(shard_count: u32) -> Self {
135        Self {
136            shard_count,
137            cache: Arc::new(RwLock::new(HashMap::new())),
138        }
139    }
140
141    /// Get the shard ID for a key using jump consistent hashing
142    pub fn get_shard(&self, key: &str) -> u32 {
143        // Check cache first
144        {
145            let cache = self.cache.read();
146            if let Some(&shard_id) = cache.get(key) {
147                return shard_id;
148            }
149        }
150
151        // Calculate using jump consistent hash
152        let shard_id = self.jump_consistent_hash(key, self.shard_count);
153
154        // Update cache
155        {
156            let mut cache = self.cache.write();
157            cache.insert(key.to_string(), shard_id);
158        }
159
160        shard_id
161    }
162
163    /// Jump consistent hash algorithm
164    /// Provides minimal key migration on shard count changes
165    fn jump_consistent_hash(&self, key: &str, num_buckets: u32) -> u32 {
166        use std::collections::hash_map::DefaultHasher;
167
168        let mut hasher = DefaultHasher::new();
169        key.hash(&mut hasher);
170        let mut hash = hasher.finish();
171
172        let mut b: i64 = -1;
173        let mut j: i64 = 0;
174
175        while j < num_buckets as i64 {
176            b = j;
177            hash = hash.wrapping_mul(2862933555777941757).wrapping_add(1);
178            j = ((b.wrapping_add(1) as f64)
179                * ((1i64 << 31) as f64 / ((hash >> 33).wrapping_add(1) as f64)))
180                as i64;
181        }
182
183        b as u32
184    }
185
186    /// Get shard ID for a vector ID
187    pub fn get_shard_for_vector(&self, vector_id: &str) -> u32 {
188        self.get_shard(vector_id)
189    }
190
191    /// Get shard IDs for a range query (may span multiple shards)
192    pub fn get_shards_for_range(&self, _start: &str, _end: &str) -> Vec<u32> {
193        // For range queries, we might need to check multiple shards
194        // For simplicity, return all shards (can be optimized based on key distribution)
195        (0..self.shard_count).collect()
196    }
197
198    /// Clear the routing cache
199    pub fn clear_cache(&self) {
200        let mut cache = self.cache.write();
201        cache.clear();
202    }
203
204    /// Get cache statistics
205    pub fn cache_stats(&self) -> CacheStats {
206        let cache = self.cache.read();
207        CacheStats {
208            entries: cache.len(),
209            shard_count: self.shard_count as usize,
210        }
211    }
212}
213
/// Cache statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of cached key -> shard assignments
    pub entries: usize,
    /// Total number of shards the router distributes keys over
    pub shard_count: usize,
}
220
/// Shard migration manager: tracks progress of moving keys from one shard
/// to another.
pub struct ShardMigration {
    /// Source shard ID
    pub source_shard: u32,
    /// Target shard ID
    pub target_shard: u32,
    /// Migration progress, clamped to the range 0.0..=1.0
    pub progress: f64,
    /// Keys migrated so far (as last reported to `update_progress`)
    pub keys_migrated: usize,
    /// Total keys to migrate
    pub total_keys: usize,
}

impl ShardMigration {
    /// Create a new shard migration.
    ///
    /// A migration with `total_keys == 0` is considered complete
    /// immediately (`is_complete` returns true).
    pub fn new(source_shard: u32, target_shard: u32, total_keys: usize) -> Self {
        Self {
            source_shard,
            target_shard,
            progress: 0.0,
            keys_migrated: 0,
            total_keys,
        }
    }

    /// Update migration progress with the cumulative number of keys moved.
    ///
    /// Bug fix: progress is now clamped to 1.0, so over-reporting
    /// `keys_migrated` (e.g. due to retried batches) can no longer push it
    /// past 100%.
    pub fn update_progress(&mut self, keys_migrated: usize) {
        self.keys_migrated = keys_migrated;
        self.progress = if self.total_keys > 0 {
            (keys_migrated as f64 / self.total_keys as f64).min(1.0)
        } else {
            // Nothing to migrate: trivially complete.
            1.0
        };
    }

    /// Check if migration is complete.
    pub fn is_complete(&self) -> bool {
        self.progress >= 1.0 || self.keys_migrated >= self.total_keys
    }
}
262
263/// Load balancer for shard distribution
264pub struct LoadBalancer {
265    /// Shard load statistics (shard_id -> load)
266    loads: Arc<RwLock<HashMap<u32, f64>>>,
267}
268
269impl LoadBalancer {
270    /// Create a new load balancer
271    pub fn new() -> Self {
272        Self {
273            loads: Arc::new(RwLock::new(HashMap::new())),
274        }
275    }
276
277    /// Update load for a shard
278    pub fn update_load(&self, shard_id: u32, load: f64) {
279        let mut loads = self.loads.write();
280        loads.insert(shard_id, load);
281    }
282
283    /// Get load for a shard
284    pub fn get_load(&self, shard_id: u32) -> f64 {
285        let loads = self.loads.read();
286        loads.get(&shard_id).copied().unwrap_or(0.0)
287    }
288
289    /// Get the least loaded shard
290    pub fn get_least_loaded_shard(&self, shard_ids: &[u32]) -> Option<u32> {
291        let loads = self.loads.read();
292
293        shard_ids
294            .iter()
295            .min_by(|&&a, &&b| {
296                let load_a = loads.get(&a).copied().unwrap_or(0.0);
297                let load_b = loads.get(&b).copied().unwrap_or(0.0);
298                load_a
299                    .partial_cmp(&load_b)
300                    .unwrap_or(std::cmp::Ordering::Equal)
301            })
302            .copied()
303    }
304
305    /// Get load statistics
306    pub fn get_stats(&self) -> LoadStats {
307        let loads = self.loads.read();
308
309        let total: f64 = loads.values().sum();
310        let count = loads.len();
311        let avg = if count > 0 { total / count as f64 } else { 0.0 };
312
313        let max = loads.values().copied().fold(f64::NEG_INFINITY, f64::max);
314        let min = loads.values().copied().fold(f64::INFINITY, f64::min);
315
316        LoadStats {
317            total_load: total,
318            avg_load: avg,
319            max_load: if max.is_finite() { max } else { 0.0 },
320            min_load: if min.is_finite() { min } else { 0.0 },
321            shard_count: count,
322        }
323    }
324}
325
326impl Default for LoadBalancer {
327    fn default() -> Self {
328        Self::new()
329    }
330}
331
/// Load statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    /// Sum of all recorded shard loads
    pub total_load: f64,
    /// Mean load across shards with a recorded value (0.0 when none)
    pub avg_load: f64,
    /// Largest recorded load (0.0 when no shards are recorded)
    pub max_load: f64,
    /// Smallest recorded load (0.0 when no shards are recorded)
    pub min_load: f64,
    /// Number of shards with a recorded load
    pub shard_count: usize,
}
341
#[cfg(test)]
mod tests {
    use super::*;

    // Basic ring membership: adding nodes, replica lookup, primary lookup.
    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(3);

        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());

        assert_eq!(ring.node_count(), 3);

        let nodes = ring.get_nodes("test-key", 3);
        assert_eq!(nodes.len(), 3);

        // Test primary node selection
        let primary = ring.get_primary_node("test-key");
        assert!(primary.is_some());
    }

    // With 150 virtual nodes per physical node, keys should spread roughly
    // evenly across the cluster.
    #[test]
    fn test_consistent_hashing_distribution() {
        let mut ring = ConsistentHashRing::new(3);

        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());

        let mut distribution: HashMap<String, usize> = HashMap::new();

        // Test distribution across many keys
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get_primary_node(&key) {
                *distribution.entry(node).or_insert(0) += 1;
            }
        }

        // Each node should get roughly 1/3 of the keys; the assertion below
        // allows each node's share to fall anywhere between 20% and 50%.
        for count in distribution.values() {
            let ratio = *count as f64 / 1000.0;
            assert!(ratio > 0.2 && ratio < 0.5, "Distribution ratio: {}", ratio);
        }
    }

    // Routing is stable and cached: the second lookup hits the cache.
    #[test]
    fn test_shard_router() {
        let router = ShardRouter::new(16);

        let shard1 = router.get_shard("test-key-1");
        let shard2 = router.get_shard("test-key-1"); // Should be cached

        assert_eq!(shard1, shard2);
        assert!(shard1 < 16);

        let stats = router.cache_stats();
        assert_eq!(stats.entries, 1);
    }

    // Jump consistent hash is deterministic for a fixed key and bucket count.
    #[test]
    fn test_jump_consistent_hash() {
        let router = ShardRouter::new(10);

        // Same key should always map to same shard
        let shard1 = router.get_shard("consistent-key");
        let shard2 = router.get_shard("consistent-key");

        assert_eq!(shard1, shard2);
    }

    // Progress tracking moves from 0.0 to 1.0 and reports completion.
    #[test]
    fn test_shard_migration() {
        let mut migration = ShardMigration::new(0, 1, 100);

        assert!(!migration.is_complete());
        assert_eq!(migration.progress, 0.0);

        migration.update_progress(50);
        assert_eq!(migration.progress, 0.5);

        migration.update_progress(100);
        assert!(migration.is_complete());
    }

    // The least-loaded shard wins and aggregate stats reflect all updates.
    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new();

        balancer.update_load(0, 0.5);
        balancer.update_load(1, 0.8);
        balancer.update_load(2, 0.3);

        let least_loaded = balancer.get_least_loaded_shard(&[0, 1, 2]);
        assert_eq!(least_loaded, Some(2));

        let stats = balancer.get_stats();
        assert_eq!(stats.shard_count, 3);
        assert!(stats.avg_load > 0.0);
    }
}