1use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::{BTreeMap, HashMap};
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use tracing::debug;
11
/// Number of virtual nodes (ring positions) that `ConsistentHashRing::add_node`
/// inserts into the ring for every physical node.
const VIRTUAL_NODE_COUNT: usize = 150;
13
14#[derive(Debug)]
16pub struct ConsistentHashRing {
17 ring: BTreeMap<u64, String>,
19 nodes: HashMap<String, usize>,
21 replication_factor: usize,
23}
24
25impl ConsistentHashRing {
26 pub fn new(replication_factor: usize) -> Self {
28 Self {
29 ring: BTreeMap::new(),
30 nodes: HashMap::new(),
31 replication_factor,
32 }
33 }
34
35 pub fn add_node(&mut self, node_id: String) {
37 if self.nodes.contains_key(&node_id) {
38 return;
39 }
40
41 for i in 0..VIRTUAL_NODE_COUNT {
43 let virtual_key = format!("{}:{}", node_id, i);
44 let hash = Self::hash_key(&virtual_key);
45 self.ring.insert(hash, node_id.clone());
46 }
47
48 self.nodes.insert(node_id, VIRTUAL_NODE_COUNT);
49 debug!(
50 "Added node to hash ring with {} virtual nodes",
51 VIRTUAL_NODE_COUNT
52 );
53 }
54
55 pub fn remove_node(&mut self, node_id: &str) {
57 if !self.nodes.contains_key(node_id) {
58 return;
59 }
60
61 self.ring.retain(|_, v| v != node_id);
63 self.nodes.remove(node_id);
64 debug!("Removed node from hash ring");
65 }
66
67 pub fn get_nodes(&self, key: &str, count: usize) -> Vec<String> {
69 if self.ring.is_empty() {
70 return Vec::new();
71 }
72
73 let hash = Self::hash_key(key);
74 let mut nodes = Vec::new();
75 let mut seen = std::collections::HashSet::new();
76
77 for (_, node_id) in self.ring.range(hash..) {
79 if seen.insert(node_id.clone()) {
80 nodes.push(node_id.clone());
81 if nodes.len() >= count {
82 return nodes;
83 }
84 }
85 }
86
87 for (_, node_id) in self.ring.iter() {
89 if seen.insert(node_id.clone()) {
90 nodes.push(node_id.clone());
91 if nodes.len() >= count {
92 return nodes;
93 }
94 }
95 }
96
97 nodes
98 }
99
100 pub fn get_primary_node(&self, key: &str) -> Option<String> {
102 self.get_nodes(key, 1).first().cloned()
103 }
104
105 fn hash_key(key: &str) -> u64 {
107 use std::collections::hash_map::DefaultHasher;
108 let mut hasher = DefaultHasher::new();
109 key.hash(&mut hasher);
110 hasher.finish()
111 }
112
113 pub fn node_count(&self) -> usize {
115 self.nodes.len()
116 }
117
118 pub fn list_nodes(&self) -> Vec<String> {
120 self.nodes.keys().cloned().collect()
121 }
122}
123
124pub struct ShardRouter {
126 shard_count: u32,
128 cache: Arc<RwLock<HashMap<String, u32>>>,
130}
131
132impl ShardRouter {
133 pub fn new(shard_count: u32) -> Self {
135 Self {
136 shard_count,
137 cache: Arc::new(RwLock::new(HashMap::new())),
138 }
139 }
140
141 pub fn get_shard(&self, key: &str) -> u32 {
143 {
145 let cache = self.cache.read();
146 if let Some(&shard_id) = cache.get(key) {
147 return shard_id;
148 }
149 }
150
151 let shard_id = self.jump_consistent_hash(key, self.shard_count);
153
154 {
156 let mut cache = self.cache.write();
157 cache.insert(key.to_string(), shard_id);
158 }
159
160 shard_id
161 }
162
163 fn jump_consistent_hash(&self, key: &str, num_buckets: u32) -> u32 {
166 use std::collections::hash_map::DefaultHasher;
167
168 let mut hasher = DefaultHasher::new();
169 key.hash(&mut hasher);
170 let mut hash = hasher.finish();
171
172 let mut b: i64 = -1;
173 let mut j: i64 = 0;
174
175 while j < num_buckets as i64 {
176 b = j;
177 hash = hash.wrapping_mul(2862933555777941757).wrapping_add(1);
178 j = ((b.wrapping_add(1) as f64)
179 * ((1i64 << 31) as f64 / ((hash >> 33).wrapping_add(1) as f64)))
180 as i64;
181 }
182
183 b as u32
184 }
185
186 pub fn get_shard_for_vector(&self, vector_id: &str) -> u32 {
188 self.get_shard(vector_id)
189 }
190
191 pub fn get_shards_for_range(&self, _start: &str, _end: &str) -> Vec<u32> {
193 (0..self.shard_count).collect()
196 }
197
198 pub fn clear_cache(&self) {
200 let mut cache = self.cache.write();
201 cache.clear();
202 }
203
204 pub fn cache_stats(&self) -> CacheStats {
206 let cache = self.cache.read();
207 CacheStats {
208 entries: cache.len(),
209 shard_count: self.shard_count as usize,
210 }
211 }
212}
213
/// Snapshot of a `ShardRouter`'s cache, as produced by `ShardRouter::cache_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of key -> shard assignments held in the cache at snapshot time.
    pub entries: usize,
    /// Number of shards the router was configured with.
    pub shard_count: usize,
}
220
/// Progress tracker for moving keys from one shard to another.
#[derive(Debug, Clone)]
pub struct ShardMigration {
    /// Shard the keys are moved away from.
    pub source_shard: u32,
    /// Shard the keys are moved to.
    pub target_shard: u32,
    /// Fraction of the migration that is done, always within `0.0..=1.0`.
    pub progress: f64,
    /// Number of keys reported as migrated so far.
    pub keys_migrated: usize,
    /// Total number of keys that have to be migrated.
    pub total_keys: usize,
}

impl ShardMigration {
    /// Creates a migration with no keys migrated yet.
    ///
    /// When `total_keys` is zero there is nothing to move, so the migration
    /// starts out with `progress == 1.0`, matching `is_complete()`.
    pub fn new(source_shard: u32, target_shard: u32, total_keys: usize) -> Self {
        Self {
            source_shard,
            target_shard,
            progress: Self::fraction_done(0, total_keys),
            keys_migrated: 0,
            total_keys,
        }
    }

    /// Records the absolute number of keys migrated so far and refreshes
    /// `progress`.
    ///
    /// `keys_migrated` is stored as given; `progress` is capped at `1.0` even
    /// if the count exceeds `total_keys`.
    pub fn update_progress(&mut self, keys_migrated: usize) {
        self.keys_migrated = keys_migrated;
        self.progress = Self::fraction_done(keys_migrated, self.total_keys);
    }

    /// Returns `true` once progress has reached 100% or at least `total_keys`
    /// keys have been migrated.
    pub fn is_complete(&self) -> bool {
        self.progress >= 1.0 || self.keys_migrated >= self.total_keys
    }

    /// Completed fraction in `0.0..=1.0`; an empty migration counts as done.
    fn fraction_done(keys_migrated: usize, total_keys: usize) -> f64 {
        if total_keys == 0 {
            1.0
        } else {
            (keys_migrated as f64 / total_keys as f64).min(1.0)
        }
    }
}
262
263pub struct LoadBalancer {
265 loads: Arc<RwLock<HashMap<u32, f64>>>,
267}
268
269impl LoadBalancer {
270 pub fn new() -> Self {
272 Self {
273 loads: Arc::new(RwLock::new(HashMap::new())),
274 }
275 }
276
277 pub fn update_load(&self, shard_id: u32, load: f64) {
279 let mut loads = self.loads.write();
280 loads.insert(shard_id, load);
281 }
282
283 pub fn get_load(&self, shard_id: u32) -> f64 {
285 let loads = self.loads.read();
286 loads.get(&shard_id).copied().unwrap_or(0.0)
287 }
288
289 pub fn get_least_loaded_shard(&self, shard_ids: &[u32]) -> Option<u32> {
291 let loads = self.loads.read();
292
293 shard_ids
294 .iter()
295 .min_by(|&&a, &&b| {
296 let load_a = loads.get(&a).copied().unwrap_or(0.0);
297 let load_b = loads.get(&b).copied().unwrap_or(0.0);
298 load_a
299 .partial_cmp(&load_b)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 })
302 .copied()
303 }
304
305 pub fn get_stats(&self) -> LoadStats {
307 let loads = self.loads.read();
308
309 let total: f64 = loads.values().sum();
310 let count = loads.len();
311 let avg = if count > 0 { total / count as f64 } else { 0.0 };
312
313 let max = loads.values().copied().fold(f64::NEG_INFINITY, f64::max);
314 let min = loads.values().copied().fold(f64::INFINITY, f64::min);
315
316 LoadStats {
317 total_load: total,
318 avg_load: avg,
319 max_load: if max.is_finite() { max } else { 0.0 },
320 min_load: if min.is_finite() { min } else { 0.0 },
321 shard_count: count,
322 }
323 }
324}
325
326impl Default for LoadBalancer {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
/// Aggregate view of the loads recorded in a `LoadBalancer`, as produced by
/// `LoadBalancer::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    /// Sum of all recorded loads.
    pub total_load: f64,
    /// `total_load` divided by `shard_count`, or `0.0` when no load is recorded.
    pub avg_load: f64,
    /// Largest recorded load, or `0.0` when that value is not finite.
    pub max_load: f64,
    /// Smallest recorded load, or `0.0` when that value is not finite.
    pub min_load: f64,
    /// Number of shards that have a recorded load.
    pub shard_count: usize,
}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Three nodes yield three distinct owners, and the primary is the first.
    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(3);

        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());

        assert_eq!(ring.node_count(), 3);

        let nodes = ring.get_nodes("test-key", 3);
        assert_eq!(nodes.len(), 3);

        let distinct: std::collections::HashSet<&String> = nodes.iter().collect();
        assert_eq!(distinct.len(), 3);

        let primary = ring.get_primary_node("test-key");
        assert!(primary.is_some());
        assert_eq!(primary.as_ref(), nodes.first());
    }

    /// Keys spread roughly evenly over three nodes.
    #[test]
    fn test_consistent_hashing_distribution() {
        let mut ring = ConsistentHashRing::new(3);

        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());

        let mut distribution: HashMap<String, usize> = HashMap::new();

        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get_primary_node(&key) {
                *distribution.entry(node).or_insert(0) += 1;
            }
        }

        for count in distribution.values() {
            let ratio = *count as f64 / 1000.0;
            assert!(ratio > 0.2 && ratio < 0.5, "Distribution ratio: {}", ratio);
        }
    }

    /// Repeated lookups agree, stay in range and populate the cache once.
    #[test]
    fn test_shard_router() {
        let router = ShardRouter::new(16);

        let shard1 = router.get_shard("test-key-1");
        let shard2 = router.get_shard("test-key-1");

        assert_eq!(shard1, shard2);
        assert!(shard1 < 16);

        let stats = router.cache_stats();
        assert_eq!(stats.entries, 1);
    }

    /// The hash itself is deterministic, not merely the cached answer.
    #[test]
    fn test_jump_consistent_hash() {
        let router = ShardRouter::new(10);

        let shard1 = router.get_shard("consistent-key");
        // Drop the memoized entry so the second lookup recomputes the hash
        // instead of being answered from the cache.
        router.clear_cache();
        let shard2 = router.get_shard("consistent-key");

        assert_eq!(shard1, shard2);
        assert!(shard1 < 10);

        // An independent router with the same shard count must agree.
        let other = ShardRouter::new(10);
        assert_eq!(other.get_shard("consistent-key"), shard1);
    }

    /// Growing from 10 to 11 shards only ever moves keys to the new shard.
    #[test]
    fn test_jump_hash_minimal_movement() {
        let before = ShardRouter::new(10);
        let after = ShardRouter::new(11);

        for i in 0..200 {
            let key = format!("key{}", i);
            let old_shard = before.get_shard(&key);
            let new_shard = after.get_shard(&key);
            assert!(
                new_shard == old_shard || new_shard == 10,
                "key {} moved from shard {} to shard {}",
                key,
                old_shard,
                new_shard
            );
        }
    }

    /// Progress follows the migrated-key count up to completion.
    #[test]
    fn test_shard_migration() {
        let mut migration = ShardMigration::new(0, 1, 100);

        assert!(!migration.is_complete());
        assert_eq!(migration.progress, 0.0);

        migration.update_progress(50);
        assert_eq!(migration.progress, 0.5);

        migration.update_progress(100);
        assert!(migration.is_complete());
    }

    /// The shard with the smallest recorded load is selected.
    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new();

        balancer.update_load(0, 0.5);
        balancer.update_load(1, 0.8);
        balancer.update_load(2, 0.3);

        let least_loaded = balancer.get_least_loaded_shard(&[0, 1, 2]);
        assert_eq!(least_loaded, Some(2));

        let stats = balancer.get_stats();
        assert_eq!(stats.shard_count, 3);
        assert!(stats.avg_load > 0.0);
    }
}