//! Distributed vector store: sharding, replication, and cluster membership,
//! with optional Raft-based coordination behind the `async` feature.

#[cfg(feature = "async")]
pub mod raft;

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

#[cfg(feature = "async")]
use std::sync::Arc;
#[cfg(feature = "async")]
use tokio::sync::RwLock;

/// How vectors are assigned to shards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardingStrategy {
    /// Simple hash of the key modulo the number of shards.
    Hash,
    /// Consistent hashing over a ring of virtual nodes.
    ConsistentHash,
    /// Range-based assignment derived from the key's leading byte.
    Range,
    /// Pseudo-random assignment (derived from the key length; not uniform).
    Random,
}

/// How many replicas must acknowledge an operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConsistencyLevel {
    /// A single replica is sufficient.
    One,
    /// A majority of replicas must acknowledge.
    Quorum,
    /// Every replica must acknowledge.
    All,
}

/// How writes are propagated between replicas.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReplicationStrategy {
    /// A single primary accepts writes and forwards them to backups.
    PrimaryBackup,
    /// Any replica accepts writes.
    MultiMaster,
    /// Writes flow through replicas in a chain.
    Chain,
}

/// Configuration for a [`DistributedStore`].
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Number of shards the key space is split into.
    pub num_shards: usize,
    /// Number of copies of each shard (primary plus replicas).
    pub replication_factor: usize,
    /// Strategy used to map keys to shards.
    pub sharding_strategy: ShardingStrategy,
    /// Consistency level required for reads and writes.
    pub consistency_level: ConsistencyLevel,
    /// Strategy used to propagate writes between replicas.
    pub replication_strategy: ReplicationStrategy,
    /// Interval between node heartbeats, in milliseconds.
    pub heartbeat_interval_ms: u64,
    /// Time without a heartbeat after which a node is considered failed, in milliseconds.
    pub failure_timeout_ms: u64,
    /// Automatically rebalance shards when nodes join or leave.
    pub auto_rebalance: bool,
    /// Maximum size of a single shard, in bytes.
    pub max_shard_size_bytes: usize,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            num_shards: 8,
            replication_factor: 3,
            sharding_strategy: ShardingStrategy::ConsistentHash,
            consistency_level: ConsistencyLevel::Quorum,
            replication_strategy: ReplicationStrategy::PrimaryBackup,
            heartbeat_interval_ms: 1000,
            failure_timeout_ms: 5000,
            auto_rebalance: true,
            max_shard_size_bytes: 100 * 1024 * 1024, // 100 MiB
        }
    }
}

impl DistributedConfig {
    /// Creates a configuration with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of shards.
    ///
    /// # Panics
    ///
    /// Panics if `num` is zero.
    pub fn with_num_shards(mut self, num: usize) -> Self {
        if num == 0 {
            panic!("Number of shards must be at least 1, got 0");
        }
        self.num_shards = num;
        self
    }

    /// Sets the replication factor (primary plus replicas).
    pub fn with_replication_factor(mut self, factor: usize) -> Self {
        self.replication_factor = factor;
        self
    }

    /// Sets the sharding strategy.
    pub fn with_sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
        self.sharding_strategy = strategy;
        self
    }

    /// Sets the consistency level.
    pub fn with_consistency(mut self, level: ConsistencyLevel) -> Self {
        self.consistency_level = level;
        self
    }

    /// Sets the replication strategy.
    pub fn with_replication_strategy(mut self, strategy: ReplicationStrategy) -> Self {
        self.replication_strategy = strategy;
        self
    }
}

/// Information about a node in the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    pub id: String,
    pub address: String,
    pub status: NodeStatus,
    pub last_heartbeat: u64,
    pub shards: Vec<usize>,
    pub capacity_bytes: usize,
    pub used_bytes: usize,
}

/// Health status of a cluster node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    Healthy,
    Degraded,
    Failed,
    Joining,
    Leaving,
}

/// Placement and size information for a single shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardInfo {
    pub id: usize,
    pub primary_node: String,
    pub replica_nodes: Vec<String>,
    pub size_bytes: usize,
    pub num_vectors: usize,
}

/// Result of querying a single shard.
#[derive(Debug, Clone)]
pub struct ShardQueryResult {
    pub shard_id: usize,
    pub node_id: String,
    pub results: Vec<(String, f32)>,
    pub latency_ms: f64,
}

/// Aggregate statistics for the cluster.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    pub total_nodes: usize,
    pub healthy_nodes: usize,
    pub total_shards: usize,
    pub total_vectors: usize,
    pub total_bytes: usize,
    pub avg_shard_size_bytes: usize,
    pub rebalances_performed: usize,
    pub queries_total: u64,
    pub queries_failed: u64,
    pub avg_query_latency_ms: f64,
}

/// Consistent hash ring with virtual nodes.
pub struct ConsistentHashRing {
    virtual_nodes: usize,
    /// Ring entries as `(hash, node_id)`, kept sorted by hash.
    ring: Vec<(u64, String)>,
}

impl ConsistentHashRing {
    /// Creates an empty ring that will insert `virtual_nodes` entries per node.
    pub fn new(virtual_nodes: usize) -> Self {
        Self {
            virtual_nodes,
            ring: Vec::new(),
        }
    }

    /// Adds a node to the ring as `virtual_nodes` hashed entries.
    pub fn add_node(&mut self, node_id: &str) {
        for i in 0..self.virtual_nodes {
            let key = format!("{}:{}", node_id, i);
            let hash = Self::hash(&key);
            self.ring.push((hash, node_id.to_string()));
        }
        self.ring.sort_by_key(|&(h, _)| h);
    }

    /// Removes all of a node's virtual entries from the ring.
    pub fn remove_node(&mut self, node_id: &str) {
        self.ring.retain(|(_, id)| id != node_id);
    }

    /// Returns the node responsible for `key`, if the ring is non-empty.
    pub fn get_node(&self, key: &str) -> Option<String> {
        if self.ring.is_empty() {
            return None;
        }

        let hash = Self::hash(key);

        // First ring entry at or after the key's hash, wrapping around to the start.
        let idx = self.ring.partition_point(|&(h, _)| h < hash);
        let idx = if idx >= self.ring.len() { 0 } else { idx };

        Some(self.ring[idx].1.clone())
    }

    /// Returns up to `count` distinct nodes for `key`, walking clockwise around the ring.
    pub fn get_nodes(&self, key: &str, count: usize) -> Vec<String> {
        if self.ring.is_empty() {
            return vec![];
        }

        let hash = Self::hash(key);
        let mut seen = HashSet::new();
        let mut nodes = Vec::new();

        let start_idx = self.ring.partition_point(|&(h, _)| h < hash);

        for i in 0..self.ring.len() {
            let idx = (start_idx + i) % self.ring.len();
            let node_id = &self.ring[idx].1;

            if seen.insert(node_id.clone()) {
                nodes.push(node_id.clone());
                if nodes.len() >= count {
                    break;
                }
            }
        }

        nodes
    }

    fn hash(key: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }
}

/// A sharded, replicated store (synchronous API).
#[cfg(not(feature = "async"))]
pub struct DistributedStore {
    config: DistributedConfig,
    nodes: HashMap<String, NodeInfo>,
    shards: HashMap<usize, ShardInfo>,
    hash_ring: ConsistentHashRing,
    stats: DistributedStats,
}

#[cfg(not(feature = "async"))]
impl DistributedStore {
    /// Creates a store with `config.num_shards` empty, unassigned shards.
    pub fn create(config: DistributedConfig) -> Result<Self> {
        if config.num_shards == 0 {
            return Err(anyhow!(
                "Number of shards must be at least 1, got 0. Use DistributedConfig::with_num_shards() to set a valid shard count."
            ));
        }

        let mut shards = HashMap::new();
        for i in 0..config.num_shards {
            shards.insert(
                i,
                ShardInfo {
                    id: i,
                    primary_node: String::new(),
                    replica_nodes: Vec::new(),
                    size_bytes: 0,
                    num_vectors: 0,
                },
            );
        }

        Ok(Self {
            config,
            nodes: HashMap::new(),
            shards,
            hash_ring: ConsistentHashRing::new(150), // 150 virtual nodes per physical node
            stats: DistributedStats::default(),
        })
    }

    /// Registers a node and, if `auto_rebalance` is set, reassigns shards.
    pub fn add_node(&mut self, node_id: &str, address: &str) -> Result<()> {
        let node = NodeInfo {
            id: node_id.to_string(),
            address: address.to_string(),
            status: NodeStatus::Joining,
            last_heartbeat: current_timestamp(),
            shards: Vec::new(),
            capacity_bytes: 10 * 1024 * 1024 * 1024, // 10 GiB default capacity
            used_bytes: 0,
        };

        self.nodes.insert(node_id.to_string(), node);
        self.hash_ring.add_node(node_id);
        self.stats.total_nodes += 1;

        if self.config.auto_rebalance {
            self.rebalance()?;
        }

        Ok(())
    }

    /// Removes a node and, if `auto_rebalance` is set, reassigns shards.
    pub fn remove_node(&mut self, node_id: &str) -> Result<()> {
        self.nodes.remove(node_id);
        self.hash_ring.remove_node(node_id);
        self.stats.total_nodes = self.stats.total_nodes.saturating_sub(1);

        if self.config.auto_rebalance {
            self.rebalance()?;
        }

        Ok(())
    }

    /// Maps a key to a shard id according to the configured sharding strategy.
    pub fn get_shard_id(&self, key: &str) -> usize {
        match self.config.sharding_strategy {
            ShardingStrategy::Hash => {
                use std::collections::hash_map::DefaultHasher;
                use std::hash::{Hash, Hasher};

                let mut hasher = DefaultHasher::new();
                key.hash(&mut hasher);
                (hasher.finish() as usize) % self.config.num_shards
            }
            ShardingStrategy::ConsistentHash => {
                // Resolve the owning node on the ring, then fold its id into a shard index.
                let node = self.hash_ring.get_node(key).unwrap_or_default();
                let sum = node
                    .as_bytes()
                    .iter()
                    .fold(0u32, |acc, &b| acc.wrapping_add(b as u32));
                (sum as usize) % self.config.num_shards
            }
            ShardingStrategy::Range => {
                key.as_bytes().first().copied().unwrap_or(0) as usize % self.config.num_shards
            }
            ShardingStrategy::Random => {
                key.len() % self.config.num_shards
            }
        }
    }

    /// Reassigns primary and replica nodes for every shard over the current node set.
    pub fn rebalance(&mut self) -> Result<()> {
        if self.nodes.is_empty() {
            return Ok(());
        }

        let node_ids: Vec<String> = self.nodes.keys().cloned().collect();

        for (shard_id, shard_info) in &mut self.shards {
            let idx = *shard_id % node_ids.len();
            shard_info.primary_node = node_ids[idx].clone();

            // Assign the remaining replicas to the next nodes in round-robin order.
            shard_info.replica_nodes.clear();
            for i in 1..self.config.replication_factor {
                let replica_idx = (idx + i) % node_ids.len();
                shard_info.replica_nodes.push(node_ids[replica_idx].clone());
            }
        }

        self.stats.rebalances_performed += 1;

        Ok(())
    }

    /// Returns a reference to the aggregate cluster statistics.
    pub fn stats(&self) -> &DistributedStats {
        &self.stats
    }

    /// Fraction of nodes that are healthy, or 0.0 for an empty cluster.
    pub fn cluster_health(&self) -> f32 {
        if self.stats.total_nodes == 0 {
            return 0.0;
        }
        self.stats.healthy_nodes as f32 / self.stats.total_nodes as f32
    }
}

/// A sharded, replicated store (asynchronous API, `async` feature).
#[cfg(feature = "async")]
pub struct DistributedStore {
    config: DistributedConfig,
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
    shards: Arc<RwLock<HashMap<usize, ShardInfo>>>,
    hash_ring: Arc<RwLock<ConsistentHashRing>>,
    stats: Arc<RwLock<DistributedStats>>,
    /// Optional Raft node used to coordinate membership and replication.
    raft_node: Option<Arc<raft::RaftNode>>,
}

#[cfg(feature = "async")]
impl DistributedStore {
    /// Creates a store with `config.num_shards` empty, unassigned shards.
    pub async fn create(config: DistributedConfig) -> Result<Self> {
        let mut shards = HashMap::new();
        for i in 0..config.num_shards {
            shards.insert(
                i,
                ShardInfo {
                    id: i,
                    primary_node: String::new(),
                    replica_nodes: Vec::new(),
                    size_bytes: 0,
                    num_vectors: 0,
                },
            );
        }

        Ok(Self {
            config,
            nodes: Arc::new(RwLock::new(HashMap::new())),
            shards: Arc::new(RwLock::new(shards)),
            hash_ring: Arc::new(RwLock::new(ConsistentHashRing::new(150))),
            stats: Arc::new(RwLock::new(DistributedStats::default())),
            raft_node: None,
        })
    }

    /// Enables Raft-based coordination for this store.
    ///
    /// Once enabled, membership changes and replica operations are first
    /// appended to the Raft log and are rejected on non-leader nodes.
    pub async fn enable_raft(&mut self, node_id: String, peer_ids: Vec<String>) -> Result<()> {
        let raft_config = raft::RaftConfig {
            node_id,
            peers: peer_ids,
            ..Default::default()
        };

        let raft_node = raft::RaftNode::new(raft_config);
        self.raft_node = Some(Arc::new(raft_node));

        Ok(())
    }

    /// Returns `true` if Raft coordination has been enabled.
    pub fn is_raft_enabled(&self) -> bool {
        self.raft_node.is_some()
    }

    /// Returns a handle to the underlying Raft node, if enabled.
    pub fn raft_node(&self) -> Option<Arc<raft::RaftNode>> {
        self.raft_node.clone()
    }

    /// Registers a node and, if `auto_rebalance` is set, reassigns shards.
    ///
    /// When Raft is enabled, only the leader may add nodes; the membership
    /// change is appended to the Raft log before local state is updated.
    pub async fn add_node(&self, node_id: &str, address: &str) -> Result<()> {
        if let Some(raft) = &self.raft_node {
            if !raft.is_leader().await {
                return Err(anyhow!("Not the leader - cannot add nodes"));
            }

            // Record the membership change in the replicated log.
            let command = raft::Command::Insert {
                id: format!("node:{}", node_id),
                vector: vec![],
                metadata: serde_json::json!({
                    "type": "add_node",
                    "node_id": node_id,
                    "address": address,
                }),
            };

            raft.append_entry(command).await.map_err(|e| anyhow!(e))?;
        }

        let node = NodeInfo {
            id: node_id.to_string(),
            address: address.to_string(),
            status: NodeStatus::Joining,
            last_heartbeat: current_timestamp(),
            shards: Vec::new(),
            capacity_bytes: 10 * 1024 * 1024 * 1024, // 10 GiB default capacity
            used_bytes: 0,
        };

        {
            let mut nodes = self.nodes.write().await;
            nodes.insert(node_id.to_string(), node);
        }

        {
            let mut ring = self.hash_ring.write().await;
            ring.add_node(node_id);
        }

        {
            let mut stats = self.stats.write().await;
            stats.total_nodes += 1;
        }

        if self.config.auto_rebalance {
            self.rebalance().await?;
        }

        Ok(())
    }

    /// Removes a node and, if `auto_rebalance` is set, reassigns shards.
    ///
    /// When Raft is enabled, only the leader may remove nodes.
    pub async fn remove_node(&self, node_id: &str) -> Result<()> {
        if let Some(raft) = &self.raft_node {
            if !raft.is_leader().await {
                return Err(anyhow!("Not the leader - cannot remove nodes"));
            }

            // Record the membership change in the replicated log.
            let command = raft::Command::Delete {
                id: format!("node:{}", node_id),
            };

            raft.append_entry(command).await.map_err(|e| anyhow!(e))?;
        }

        {
            let mut nodes = self.nodes.write().await;
            nodes.remove(node_id);
        }

        {
            let mut ring = self.hash_ring.write().await;
            ring.remove_node(node_id);
        }

        {
            let mut stats = self.stats.write().await;
            stats.total_nodes = stats.total_nodes.saturating_sub(1);
        }

        if self.config.auto_rebalance {
            self.rebalance().await?;
        }

        Ok(())
    }

    /// Maps a key to a shard id according to the configured sharding strategy.
    pub async fn get_shard_id(&self, key: &str) -> usize {
        match self.config.sharding_strategy {
            ShardingStrategy::Hash => {
                use std::collections::hash_map::DefaultHasher;
                use std::hash::{Hash, Hasher};

                let mut hasher = DefaultHasher::new();
                key.hash(&mut hasher);
                (hasher.finish() as usize) % self.config.num_shards
            }
            ShardingStrategy::ConsistentHash => {
                let ring = self.hash_ring.read().await;
                let node = ring.get_node(key).unwrap_or_default();
                let sum = node
                    .as_bytes()
                    .iter()
                    .fold(0u32, |acc, &b| acc.wrapping_add(b as u32));
                (sum as usize) % self.config.num_shards
            }
            ShardingStrategy::Range => {
                key.as_bytes().first().copied().unwrap_or(0) as usize % self.config.num_shards
            }
            ShardingStrategy::Random => key.len() % self.config.num_shards,
        }
    }

    /// Reassigns primary and replica nodes for every shard over the current node set.
    pub async fn rebalance(&self) -> Result<()> {
        let nodes = self.nodes.read().await;
        if nodes.is_empty() {
            return Ok(());
        }

        let node_ids: Vec<String> = nodes.keys().cloned().collect();
        drop(nodes);

        let mut shards = self.shards.write().await;

        for (shard_id, shard_info) in shards.iter_mut() {
            let idx = *shard_id % node_ids.len();
            shard_info.primary_node = node_ids[idx].clone();

            shard_info.replica_nodes.clear();
            for i in 1..self.config.replication_factor {
                let replica_idx = (idx + i) % node_ids.len();
                shard_info.replica_nodes.push(node_ids[replica_idx].clone());
            }
        }

        let mut stats = self.stats.write().await;
        stats.rebalances_performed += 1;

        Ok(())
    }

    /// Returns a snapshot of the aggregate cluster statistics.
    pub async fn stats(&self) -> DistributedStats {
        self.stats.read().await.clone()
    }

    /// Fraction of nodes that are healthy, or 0.0 for an empty cluster.
    pub async fn cluster_health(&self) -> f32 {
        let stats = self.stats.read().await;
        if stats.total_nodes == 0 {
            return 0.0;
        }
        stats.healthy_nodes as f32 / stats.total_nodes as f32
    }

    /// Returns the replica nodes assigned to `shard_id`.
    pub async fn get_replicas(&self, shard_id: usize) -> Result<Vec<String>> {
        let shards = self.shards.read().await;
        let shard = shards
            .get(&shard_id)
            .ok_or_else(|| anyhow!("Shard {} not found", shard_id))?;

        Ok(shard.replica_nodes.clone())
    }

    /// Propagates shard data to its replicas according to the configured
    /// consistency level. When Raft is enabled, the sync is first recorded in
    /// the replicated log and only the leader may initiate it.
    pub async fn sync_to_replicas(&self, shard_id: usize, data: Vec<u8>) -> Result<()> {
        let replicas = self.get_replicas(shard_id).await?;

        if replicas.is_empty() {
            return Ok(());
        }

        if let Some(raft) = &self.raft_node {
            if !raft.is_leader().await {
                return Err(anyhow!("Not the leader - cannot sync replicas"));
            }

            let command = raft::Command::Update {
                id: format!("shard:{}:sync", shard_id),
                vector: vec![],
                metadata: serde_json::json!({
                    "type": "replica_sync",
                    "shard_id": shard_id,
                    "data_size": data.len(),
                }),
            };

            raft.append_entry(command).await.map_err(|e| anyhow!(e))?;
        }

        // The network transfer itself is not implemented here; the arms below
        // only model how many replicas each consistency level would wait for.
        match self.config.consistency_level {
            ConsistencyLevel::All => {
                for _replica in &replicas {
                    // Placeholder: send `data` to every replica and await each acknowledgement.
                }
            }
            ConsistencyLevel::Quorum => {
                let quorum_size = (replicas.len() / 2) + 1;
                for _i in 0..quorum_size {
                    // Placeholder: wait for a majority of replicas to acknowledge.
                }
            }
            ConsistencyLevel::One => {
                // Placeholder: a single acknowledgement is sufficient.
            }
        }

        Ok(())
    }

    /// Queries a shard, choosing how many replicas to consult from the
    /// configured consistency level. Result merging is not implemented here,
    /// so every consistency level currently returns an empty result set.
    pub async fn query_from_replicas(
        &self,
        shard_id: usize,
        _query: Vec<f32>,
        _k: usize,
    ) -> Result<Vec<(String, f32)>> {
        let shards = self.shards.read().await;
        let shard = shards
            .get(&shard_id)
            .ok_or_else(|| anyhow!("Shard {} not found", shard_id))?;

        // Primary first, then replicas.
        let mut available_nodes = vec![shard.primary_node.clone()];
        available_nodes.extend(shard.replica_nodes.iter().cloned());

        if available_nodes.is_empty() {
            return Err(anyhow!("No nodes available for shard {}", shard_id));
        }

        match self.config.consistency_level {
            ConsistencyLevel::All => {
                // Placeholder: query every node and merge the results.
                Ok(vec![])
            }
            ConsistencyLevel::Quorum => {
                // Placeholder: query a majority of nodes and merge the results.
                Ok(vec![])
            }
            ConsistencyLevel::One => {
                // Placeholder: query the first available node.
                Ok(vec![])
            }
        }
    }

    /// Promotes a replica of `shard_id` to primary, demoting the old primary
    /// to a replica. When Raft is enabled, only the leader may promote.
    pub async fn promote_replica(&self, shard_id: usize, new_primary: String) -> Result<()> {
        if let Some(raft) = &self.raft_node {
            if !raft.is_leader().await {
                return Err(anyhow!("Not the leader - cannot promote replica"));
            }

            let command = raft::Command::Update {
                id: format!("shard:{}:promote", shard_id),
                vector: vec![],
                metadata: serde_json::json!({
                    "type": "promote_replica",
                    "shard_id": shard_id,
                    "new_primary": new_primary,
                }),
            };

            raft.append_entry(command).await.map_err(|e| anyhow!(e))?;
        }

        let mut shards = self.shards.write().await;
        let shard = shards
            .get_mut(&shard_id)
            .ok_or_else(|| anyhow!("Shard {} not found", shard_id))?;

        if !shard.replica_nodes.contains(&new_primary) {
            return Err(anyhow!(
                "Node {} is not a replica of shard {}",
                new_primary,
                shard_id
            ));
        }

        // Swap roles: the new primary leaves the replica set and the old
        // primary (if any) joins it.
        shard.replica_nodes.retain(|n| n != &new_primary);

        if !shard.primary_node.is_empty() {
            shard.replica_nodes.push(shard.primary_node.clone());
        }

        shard.primary_node = new_primary;

        Ok(())
    }

    /// Inserts a vector. Routing and storage are not implemented here; only
    /// the shard assignment is computed and the statistics are updated.
    pub async fn insert(&self, id: &str, _vector: Vec<f32>) -> Result<()> {
        // Placeholder: the vector would be stored on this shard's primary and replicas.
        let _shard_id = self.get_shard_id(id).await;

        let mut stats = self.stats.write().await;
        stats.total_vectors += 1;

        Ok(())
    }

    /// Queries all shards. The scatter-gather itself is not implemented here;
    /// only the query statistics are updated.
    pub async fn query(&self, _query: Vec<f32>, _k: usize) -> Result<Vec<(String, f32)>> {
        let mut stats = self.stats.write().await;
        stats.queries_total += 1;

        Ok(vec![])
    }
}

/// Current Unix time in seconds.
fn current_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(100);

        ring.add_node("node1");
        ring.add_node("node2");
        ring.add_node("node3");

        // The same key always maps to the same node.
        let node1 = ring.get_node("key1").unwrap();
        let node2 = ring.get_node("key1").unwrap();
        assert_eq!(node1, node2);

        // Three distinct nodes are available for replication.
        let nodes = ring.get_nodes("key1", 3);
        assert_eq!(nodes.len(), 3);
        assert!(
            nodes.contains(&"node1".to_string())
                || nodes.contains(&"node2".to_string())
                || nodes.contains(&"node3".to_string())
        );
    }
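
    // Property sketch for the ring above: removing a node should only remap
    // the keys that node owned, and `get_nodes` returns at most the number of
    // distinct nodes in the ring. Illustrative test; the name and key set are
    // arbitrary choices, not part of the original suite's naming.
    #[test]
    fn test_consistent_hash_ring_properties() {
        let mut ring = ConsistentHashRing::new(100);
        ring.add_node("node1");
        ring.add_node("node2");
        ring.add_node("node3");

        // Record the owner of a handful of keys.
        let keys: Vec<String> = (0..50).map(|i| format!("key{}", i)).collect();
        let before: Vec<String> = keys.iter().map(|k| ring.get_node(k).unwrap()).collect();

        // Removing one node must not move keys that it did not own.
        ring.remove_node("node2");
        for (key, owner) in keys.iter().zip(before.iter()) {
            if owner != "node2" {
                assert_eq!(ring.get_node(key).unwrap(), *owner);
            } else {
                assert_ne!(ring.get_node(key).unwrap(), "node2");
            }
        }

        // Asking for more replicas than nodes yields every distinct node once.
        let nodes = ring.get_nodes("key1", 10);
        assert_eq!(nodes.len(), 2);
    }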

    #[test]
    fn test_sharding_strategies() {
        let config = DistributedConfig::new().with_num_shards(4);

        #[cfg(not(feature = "async"))]
        {
            let store = DistributedStore::create(config).unwrap();

            // Shard assignment is deterministic and within range.
            let shard1 = store.get_shard_id("key1");
            let shard2 = store.get_shard_id("key1");
            assert_eq!(shard1, shard2);
            assert!(shard1 < 4);
        }
    }
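
    // Illustrative sketch of the `DistributedConfig` builder: chained setters
    // override the defaults from `Default::default()`. The specific values
    // used here are arbitrary examples.
    #[test]
    fn test_config_builder_chain() {
        let config = DistributedConfig::new()
            .with_num_shards(16)
            .with_replication_factor(2)
            .with_sharding_strategy(ShardingStrategy::Hash)
            .with_consistency(ConsistencyLevel::All)
            .with_replication_strategy(ReplicationStrategy::Chain);

        assert_eq!(config.num_shards, 16);
        assert_eq!(config.replication_factor, 2);
        assert_eq!(config.sharding_strategy, ShardingStrategy::Hash);
        assert_eq!(config.consistency_level, ConsistencyLevel::All);
        assert_eq!(config.replication_strategy, ReplicationStrategy::Chain);

        // Untouched fields keep their defaults.
        assert!(config.auto_rebalance);
        assert_eq!(config.max_shard_size_bytes, 100 * 1024 * 1024);
    }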

    #[cfg(not(feature = "async"))]
    #[test]
    fn test_add_remove_nodes() {
        let config = DistributedConfig::new();
        let mut store = DistributedStore::create(config).unwrap();

        store.add_node("node1", "127.0.0.1:8001").unwrap();
        store.add_node("node2", "127.0.0.1:8002").unwrap();

        assert_eq!(store.stats().total_nodes, 2);

        store.remove_node("node1").unwrap();
        assert_eq!(store.stats().total_nodes, 1);
    }
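
    // Sketch of the synchronous rebalance path, mirroring the async
    // `test_replica_assignment` below: with the default replication factor of
    // 3 and auto-rebalance on, every shard should get a primary and two
    // replicas once three nodes have joined. Illustrative test only.
    #[cfg(not(feature = "async"))]
    #[test]
    fn test_sync_rebalance_assigns_primaries() {
        let config = DistributedConfig::new();
        let mut store = DistributedStore::create(config).unwrap();

        store.add_node("node1", "127.0.0.1:8001").unwrap();
        store.add_node("node2", "127.0.0.1:8002").unwrap();
        store.add_node("node3", "127.0.0.1:8003").unwrap();

        for (shard_id, shard) in &store.shards {
            assert!(!shard.primary_node.is_empty(), "Shard {} has no primary", shard_id);
            assert_eq!(shard.replica_nodes.len(), 2, "Shard {} should have 2 replicas", shard_id);
        }
    }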

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_distributed_store() {
        let config = DistributedConfig::new();
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();

        let stats = store.stats().await;
        assert_eq!(stats.total_nodes, 2);

        store.insert("doc1", vec![0.1, 0.2, 0.3]).await.unwrap();

        let stats = store.stats().await;
        assert_eq!(stats.total_vectors, 1);
    }

    #[test]
    fn test_cluster_health() {
        let config = DistributedConfig::new();

        #[cfg(not(feature = "async"))]
        {
            let mut store = DistributedStore::create(config).unwrap();
            store.add_node("node1", "127.0.0.1:8001").unwrap();

            store.stats.healthy_nodes = 1;
            assert_eq!(store.cluster_health(), 1.0);

            store.stats.healthy_nodes = 0;
            assert_eq!(store.cluster_health(), 0.0);
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_raft_integration() {
        let config = DistributedConfig::new();
        let mut store = DistributedStore::create(config).await.unwrap();

        assert!(!store.is_raft_enabled());

        store
            .enable_raft("node1".to_string(), vec![])
            .await
            .unwrap();
        assert!(store.is_raft_enabled());

        let raft = store.raft_node().unwrap();
        assert!(!raft.is_leader().await);

        // A single-node cluster elects itself leader.
        raft.start_election().await;
        assert!(raft.is_leader().await);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_raft_add_node_leader_check() {
        let config = DistributedConfig::new();
        let mut store = DistributedStore::create(config).await.unwrap();

        store
            .enable_raft("leader".to_string(), vec![])
            .await
            .unwrap();
        let raft = store.raft_node().unwrap();

        raft.start_election().await;
        assert!(raft.is_leader().await);

        // The leader is allowed to modify membership.
        let result = store.add_node("node1", "127.0.0.1:8001").await;
        assert!(result.is_ok());

        let stats = store.stats().await;
        assert_eq!(stats.total_nodes, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_raft_add_node_not_leader_fails() {
        let config = DistributedConfig::new();
        let mut store = DistributedStore::create(config).await.unwrap();

        store
            .enable_raft("follower".to_string(), vec!["leader".to_string()])
            .await
            .unwrap();

        let raft = store.raft_node().unwrap();
        assert!(!raft.is_leader().await);

        // Followers must reject membership changes.
        let result = store.add_node("node1", "127.0.0.1:8001").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not the leader"));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_raft_remove_node_leader_check() {
        let config = DistributedConfig::new();
        let mut store = DistributedStore::create(config).await.unwrap();

        store
            .enable_raft("leader".to_string(), vec![])
            .await
            .unwrap();
        let raft = store.raft_node().unwrap();
        raft.start_election().await;
        assert!(raft.is_leader().await);

        // A store without Raft accepts membership changes unconditionally.
        let store_without_raft = DistributedStore::create(DistributedConfig::new())
            .await
            .unwrap();
        store_without_raft
            .add_node("node1", "127.0.0.1:8001")
            .await
            .unwrap();

        // Re-enabling Raft installs a fresh Raft node, so elect it again
        // before changing membership through `store`.
        store
            .enable_raft("leader".to_string(), vec![])
            .await
            .unwrap();
        let raft = store.raft_node().unwrap();
        raft.start_election().await;

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();

        // The leader may remove nodes.
        let result = store.remove_node("node1").await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_consistent_hashing_with_raft() {
        let config = DistributedConfig {
            sharding_strategy: ShardingStrategy::ConsistentHash,
            ..Default::default()
        };
        let mut store = DistributedStore::create(config).await.unwrap();

        store
            .enable_raft("leader".to_string(), vec![])
            .await
            .unwrap();
        let raft = store.raft_node().unwrap();
        raft.start_election().await;

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();

        let key = "test-key";
        let shard_id = store.get_shard_id(key).await;
        assert!(shard_id < store.config.num_shards);

        let stats = store.stats().await;
        assert_eq!(stats.total_nodes, 3);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_replica_assignment() {
        let config = DistributedConfig {
            replication_factor: 3,
            ..Default::default()
        };
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();

        store.rebalance().await.unwrap();

        let shards = store.shards.read().await;
        for (shard_id, shard) in shards.iter() {
            assert!(
                !shard.primary_node.is_empty(),
                "Shard {} has no primary",
                shard_id
            );
            assert_eq!(
                shard.replica_nodes.len(),
                2,
                "Shard {} should have 2 replicas",
                shard_id
            );
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_get_replicas() {
        let config = DistributedConfig {
            replication_factor: 3,
            num_shards: 4,
            ..Default::default()
        };
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        let replicas = store.get_replicas(0).await.unwrap();
        assert_eq!(replicas.len(), 2);

        let result = store.get_replicas(999).await;
        assert!(result.is_err());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_sync_to_replicas_eventual_consistency() {
        let config = DistributedConfig {
            replication_factor: 3,
            consistency_level: ConsistencyLevel::One,
            ..Default::default()
        };
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        let result = store.sync_to_replicas(0, vec![1, 2, 3, 4]).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_sync_to_replicas_with_raft() {
        let config = DistributedConfig {
            replication_factor: 3,
            consistency_level: ConsistencyLevel::Quorum,
            ..Default::default()
        };
        let mut store = DistributedStore::create(config).await.unwrap();

        store
            .enable_raft("leader".to_string(), vec![])
            .await
            .unwrap();
        let raft = store.raft_node().unwrap();
        raft.start_election().await;

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        let result = store.sync_to_replicas(0, vec![1, 2, 3, 4]).await;
        assert!(result.is_ok());
    }
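
    // Sketch of the complementary failure path for `sync_to_replicas`: when
    // Raft is enabled but this node has not won an election, the sync is
    // rejected with a "Not the leader" error. Node names are arbitrary and
    // the nodes are added before Raft is enabled so membership setup succeeds.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_sync_to_replicas_not_leader_fails() {
        let config = DistributedConfig {
            replication_factor: 3,
            ..Default::default()
        };
        let mut store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        // Enable Raft as a follower; no election is started.
        store
            .enable_raft("follower".to_string(), vec!["leader".to_string()])
            .await
            .unwrap();
        assert!(!store.raft_node().unwrap().is_leader().await);

        let result = store.sync_to_replicas(0, vec![1, 2, 3, 4]).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not the leader"));
    }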

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_query_from_replicas() {
        let config = DistributedConfig {
            replication_factor: 3,
            consistency_level: ConsistencyLevel::One,
            ..Default::default()
        };
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        let result = store.query_from_replicas(0, vec![0.1, 0.2, 0.3], 10).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_query_strong_consistency() {
        let config = DistributedConfig {
            replication_factor: 3,
            consistency_level: ConsistencyLevel::All,
            ..Default::default()
        };
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.rebalance().await.unwrap();

        let result = store.query_from_replicas(0, vec![0.1, 0.2, 0.3], 10).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_promote_replica() {
        let config = DistributedConfig {
            replication_factor: 3,
            ..Default::default()
        };
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        let shards = store.shards.read().await;
        let shard = shards.get(&0).unwrap();
        let old_primary = shard.primary_node.clone();
        let new_primary = shard.replica_nodes[0].clone();
        drop(shards);

        store.promote_replica(0, new_primary.clone()).await.unwrap();

        let shards = store.shards.read().await;
        let shard = shards.get(&0).unwrap();
        assert_eq!(shard.primary_node, new_primary);
        assert!(shard.replica_nodes.contains(&old_primary));
        assert!(!shard.replica_nodes.contains(&new_primary));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_promote_replica_with_raft() {
        let config = DistributedConfig {
            replication_factor: 3,
            ..Default::default()
        };
        let mut store = DistributedStore::create(config).await.unwrap();

        store
            .enable_raft("leader".to_string(), vec![])
            .await
            .unwrap();
        let raft = store.raft_node().unwrap();
        raft.start_election().await;

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.add_node("node3", "127.0.0.1:8003").await.unwrap();
        store.rebalance().await.unwrap();

        let shards = store.shards.read().await;
        let shard = shards.get(&0).unwrap();
        let new_primary = shard.replica_nodes[0].clone();
        drop(shards);

        let result = store.promote_replica(0, new_primary).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_promote_non_replica_fails() {
        let config = DistributedConfig::default();
        let store = DistributedStore::create(config).await.unwrap();

        store.add_node("node1", "127.0.0.1:8001").await.unwrap();
        store.add_node("node2", "127.0.0.1:8002").await.unwrap();
        store.rebalance().await.unwrap();

        let result = store
            .promote_replica(0, "node-not-replica".to_string())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not a replica"));
    }
}