pub mod consensus;
pub mod discovery;
pub mod shard;

use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tracing::{debug, info, warn};

pub use consensus::DagConsensus;
pub use discovery::{DiscoveryService, GossipDiscovery, StaticDiscovery};
pub use shard::{ConsistentHashRing, ShardRouter};

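/// Errors that can arise from cluster membership, sharding, consensus, and
/// discovery operations.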
#[derive(Debug, Error)]
pub enum ClusterError {
    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Shard not found: {0}")]
    ShardNotFound(u32),

    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    #[error("Consensus error: {0}")]
    ConsensusError(String),

    #[error("Discovery error: {0}")]
    DiscoveryError(String),

    #[error("Network error: {0}")]
    NetworkError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}

pub type Result<T> = std::result::Result<T, ClusterError>;

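/// Role a node currently plays in the cluster.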
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    Leader,
    Follower,
    Candidate,
    Offline,
}

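/// A single member of the cluster, identified by `node_id` and reachable at
/// `address`. `last_seen` is refreshed by `heartbeat()` and drives health
/// checks; `capacity` is a relative weight that defaults to 1.0.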
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterNode {
    pub node_id: String,
    pub address: SocketAddr,
    pub status: NodeStatus,
    pub last_seen: DateTime<Utc>,
    pub metadata: HashMap<String, String>,
    pub capacity: f64,
}

impl ClusterNode {
    pub fn new(node_id: String, address: SocketAddr) -> Self {
        Self {
            node_id,
            address,
            status: NodeStatus::Follower,
            last_seen: Utc::now(),
            metadata: HashMap::new(),
            capacity: 1.0,
        }
    }

    pub fn is_healthy(&self, timeout: Duration) -> bool {
        let now = Utc::now();
        let elapsed = now
            .signed_duration_since(self.last_seen)
            .to_std()
            .unwrap_or(Duration::MAX);
        elapsed < timeout
    }

    pub fn heartbeat(&mut self) {
        self.last_seen = Utc::now();
    }
}

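/// Placement record for one shard: its primary node, replica nodes, current
/// status, and vector count.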
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardInfo {
    pub shard_id: u32,
    pub primary_node: String,
    pub replica_nodes: Vec<String>,
    pub vector_count: usize,
    pub status: ShardStatus,
    pub created_at: DateTime<Utc>,
    pub modified_at: DateTime<Utc>,
}

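/// Lifecycle state of a shard.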
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStatus {
    Active,
    Migrating,
    Replicating,
    Offline,
}

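/// Cluster-wide tuning knobs: replication factor, shard count, heartbeat and
/// timeout intervals, and consensus/quorum settings.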
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
    pub replication_factor: usize,
    pub shard_count: u32,
    pub heartbeat_interval: Duration,
    pub node_timeout: Duration,
    pub enable_consensus: bool,
    pub min_quorum_size: usize,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            replication_factor: 3,
            shard_count: 64,
            heartbeat_interval: Duration::from_secs(5),
            node_timeout: Duration::from_secs(30),
            enable_consensus: true,
            min_quorum_size: 2,
        }
    }
}

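/// Coordinates cluster membership, shard placement, and (optionally) DAG
/// consensus on behalf of a single node.
///
/// A minimal usage sketch; the crate path and the surrounding async runtime
/// are assumptions here, so the example is not compiled:
///
/// ```ignore
/// // Hypothetical import path, for illustration only.
/// use my_crate::cluster::{ClusterConfig, ClusterManager, ClusterNode, StaticDiscovery};
///
/// let manager = ClusterManager::new(
///     ClusterConfig::default(),
///     "node-a".to_string(),
///     Box::new(StaticDiscovery::new(vec![])),
/// )?;
///
/// // Register a peer; this also rebalances shard ownership.
/// let peer = ClusterNode::new("node-b".to_string(), "10.0.0.2:8000".parse()?);
/// manager.add_node(peer).await?;
///
/// let stats = manager.get_stats();
/// println!("{} nodes, {} shards", stats.total_nodes, stats.total_shards);
/// ```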
pub struct ClusterManager {
    config: ClusterConfig,
    nodes: Arc<DashMap<String, ClusterNode>>,
    shards: Arc<DashMap<u32, ShardInfo>>,
    hash_ring: Arc<RwLock<ConsistentHashRing>>,
    router: Arc<ShardRouter>,
    consensus: Option<Arc<DagConsensus>>,
    discovery: Box<dyn DiscoveryService>,
    node_id: String,
}

impl ClusterManager {
    pub fn new(
        config: ClusterConfig,
        node_id: String,
        discovery: Box<dyn DiscoveryService>,
    ) -> Result<Self> {
        let nodes = Arc::new(DashMap::new());
        let shards = Arc::new(DashMap::new());
        let hash_ring = Arc::new(RwLock::new(ConsistentHashRing::new(
            config.replication_factor,
        )));
        let router = Arc::new(ShardRouter::new(config.shard_count));

        let consensus = if config.enable_consensus {
            Some(Arc::new(DagConsensus::new(
                node_id.clone(),
                config.min_quorum_size,
            )))
        } else {
            None
        };

        Ok(Self {
            config,
            nodes,
            shards,
            hash_ring,
            router,
            consensus,
            discovery,
            node_id,
        })
    }

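    /// Registers `node`, adds it to the consistent-hash ring, and rebalances
    /// shard ownership across the updated membership.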
    pub async fn add_node(&self, node: ClusterNode) -> Result<()> {
        info!("Adding node {} to cluster", node.node_id);

        {
            let mut ring = self.hash_ring.write();
            ring.add_node(node.node_id.clone());
        }

        self.nodes.insert(node.node_id.clone(), node.clone());

        self.rebalance_shards().await?;

        info!("Node {} successfully added", node.node_id);
        Ok(())
    }

    pub async fn remove_node(&self, node_id: &str) -> Result<()> {
        info!("Removing node {} from cluster", node_id);

        {
            let mut ring = self.hash_ring.write();
            ring.remove_node(node_id);
        }

        self.nodes.remove(node_id);

        self.rebalance_shards().await?;

        info!("Node {} successfully removed", node_id);
        Ok(())
    }

    pub fn get_node(&self, node_id: &str) -> Option<ClusterNode> {
        self.nodes.get(node_id).map(|n| n.clone())
    }

    pub fn list_nodes(&self) -> Vec<ClusterNode> {
        self.nodes
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }

    pub fn healthy_nodes(&self) -> Vec<ClusterNode> {
        self.nodes
            .iter()
            .filter(|entry| entry.value().is_healthy(self.config.node_timeout))
            .map(|entry| entry.value().clone())
            .collect()
    }

    pub fn get_shard(&self, shard_id: u32) -> Option<ShardInfo> {
        self.shards.get(&shard_id).map(|s| s.clone())
    }

    pub fn list_shards(&self) -> Vec<ShardInfo> {
        self.shards
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }

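    /// Chooses a primary and replica set for `shard_id` from the hash ring and
    /// records the resulting `ShardInfo`. Fails if no nodes are available.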
    pub fn assign_shard(&self, shard_id: u32) -> Result<ShardInfo> {
        let ring = self.hash_ring.read();
        let key = format!("shard:{}", shard_id);

        let nodes = ring.get_nodes(&key, self.config.replication_factor);

        if nodes.is_empty() {
            return Err(ClusterError::InvalidConfig(
                "No nodes available for shard assignment".to_string(),
            ));
        }

        let primary_node = nodes[0].clone();
        let replica_nodes = nodes.into_iter().skip(1).collect();

        let shard_info = ShardInfo {
            shard_id,
            primary_node,
            replica_nodes,
            vector_count: 0,
            status: ShardStatus::Active,
            created_at: Utc::now(),
            modified_at: Utc::now(),
        };

        self.shards.insert(shard_id, shard_info.clone());
        Ok(shard_info)
    }

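    /// Recomputes primary and replica placement for every shard after a
    /// membership change, assigning any shard that does not exist yet.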
    async fn rebalance_shards(&self) -> Result<()> {
        debug!("Rebalancing shards across cluster");

        for shard_id in 0..self.config.shard_count {
            if let Some(mut shard) = self.shards.get_mut(&shard_id) {
                let ring = self.hash_ring.read();
                let key = format!("shard:{}", shard_id);
                let nodes = ring.get_nodes(&key, self.config.replication_factor);

                if !nodes.is_empty() {
                    shard.primary_node = nodes[0].clone();
                    shard.replica_nodes = nodes.into_iter().skip(1).collect();
                    shard.modified_at = Utc::now();
                }
            } else {
                self.assign_shard(shard_id)?;
            }
        }

        debug!("Shard rebalancing complete");
        Ok(())
    }

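    /// Marks nodes that have not been seen within `node_timeout` as
    /// `NodeStatus::Offline`. Unhealthy nodes are collected first, then
    /// updated in a second pass so no map guards are held across the mutation.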
    pub async fn run_health_checks(&self) -> Result<()> {
        debug!("Running health checks");

        let mut unhealthy_nodes = Vec::new();

        for entry in self.nodes.iter() {
            let node = entry.value();
            if !node.is_healthy(self.config.node_timeout) {
                warn!("Node {} is unhealthy", node.node_id);
                unhealthy_nodes.push(node.node_id.clone());
            }
        }

        for node_id in unhealthy_nodes {
            if let Some(mut node) = self.nodes.get_mut(&node_id) {
                node.status = NodeStatus::Offline;
            }
        }

        Ok(())
    }

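    /// Bootstraps the cluster view: discovers peers, registers every peer
    /// other than this node, then assigns all shards. Shard assignment fails
    /// if the hash ring is still empty at that point.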
    pub async fn start(&self) -> Result<()> {
        info!("Starting cluster manager for node {}", self.node_id);

        let discovered = self.discovery.discover_nodes().await?;
        for node in discovered {
            if node.node_id != self.node_id {
                self.add_node(node).await?;
            }
        }

        for shard_id in 0..self.config.shard_count {
            self.assign_shard(shard_id)?;
        }

        info!("Cluster manager started successfully");
        Ok(())
    }

    pub fn get_stats(&self) -> ClusterStats {
        let nodes = self.list_nodes();
        let shards = self.list_shards();
        let healthy = self.healthy_nodes();

        ClusterStats {
            total_nodes: nodes.len(),
            healthy_nodes: healthy.len(),
            total_shards: shards.len(),
            active_shards: shards
                .iter()
                .filter(|s| s.status == ShardStatus::Active)
                .count(),
            total_vectors: shards.iter().map(|s| s.vector_count).sum(),
        }
    }

    pub fn router(&self) -> Arc<ShardRouter> {
        Arc::clone(&self.router)
    }

    pub fn consensus(&self) -> Option<Arc<DagConsensus>> {
        self.consensus.as_ref().map(Arc::clone)
    }
}

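/// Point-in-time snapshot of cluster size and shard health, as reported by
/// `ClusterManager::get_stats`.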
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    pub total_nodes: usize,
    pub healthy_nodes: usize,
    pub total_shards: usize,
    pub active_shards: usize,
    pub total_vectors: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        ClusterNode::new(
            id.to_string(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
        )
    }

    #[tokio::test]
    async fn test_cluster_node_creation() {
        let node = create_test_node("node1", 8000);
        assert_eq!(node.node_id, "node1");
        assert_eq!(node.status, NodeStatus::Follower);
        assert!(node.is_healthy(Duration::from_secs(60)));
    }

    #[tokio::test]
    async fn test_cluster_manager_creation() {
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery);
        assert!(manager.is_ok());
    }

    #[tokio::test]
    async fn test_add_remove_node() {
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();

        let node = create_test_node("node1", 8000);
        manager.add_node(node).await.unwrap();

        assert_eq!(manager.list_nodes().len(), 1);

        manager.remove_node("node1").await.unwrap();
        assert_eq!(manager.list_nodes().len(), 0);
    }

    #[tokio::test]
    async fn test_shard_assignment() {
        let config = ClusterConfig {
            shard_count: 4,
            replication_factor: 2,
            ..Default::default()
        };
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();

        for i in 0..3 {
            let node = create_test_node(&format!("node{}", i), 8000 + i);
            manager.add_node(node).await.unwrap();
        }

        let shard = manager.assign_shard(0).unwrap();
        assert_eq!(shard.shard_id, 0);
        assert!(!shard.primary_node.is_empty());
    }
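
    #[tokio::test]
    async fn test_cluster_stats() {
        // Illustrative sketch: relies on rebalance_shards assigning every
        // shard once a node joins, as the tests above already exercise.
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();

        manager.add_node(create_test_node("node1", 8000)).await.unwrap();

        let stats = manager.get_stats();
        assert_eq!(stats.total_nodes, 1);
        assert_eq!(stats.healthy_nodes, 1);
        assert_eq!(stats.active_shards, stats.total_shards);
        assert_eq!(stats.total_vectors, 0);
    }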
}