ruvector_cluster/
discovery.rs

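//! Node discovery mechanisms for cluster formation.
//!
//! Supports static configuration, gossip-based discovery, and multicast discovery for
//! local networks. The network exchange for the gossip and multicast services is
//! currently stubbed out, so their membership views change only through local API calls.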
4
5use crate::{ClusterError, ClusterNode, NodeStatus, Result};
6use async_trait::async_trait;
7use chrono::Utc;
8use dashmap::DashMap;
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::time;
16use tracing::{debug, info, warn};
17
18/// Service for discovering nodes in the cluster
19#[async_trait]
20pub trait DiscoveryService: Send + Sync {
21    /// Discover nodes in the cluster
22    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>>;
23
24    /// Register this node in the discovery service
25    async fn register_node(&self, node: ClusterNode) -> Result<()>;
26
27    /// Unregister this node from the discovery service
28    async fn unregister_node(&self, node_id: &str) -> Result<()>;
29
30    /// Update node heartbeat
31    async fn heartbeat(&self, node_id: &str) -> Result<()>;
32}
33
34/// Static discovery using predefined node list
35pub struct StaticDiscovery {
36    /// Predefined list of nodes
37    nodes: Arc<RwLock<Vec<ClusterNode>>>,
38}
39
40impl StaticDiscovery {
41    /// Create a new static discovery service
42    pub fn new(nodes: Vec<ClusterNode>) -> Self {
43        Self {
44            nodes: Arc::new(RwLock::new(nodes)),
45        }
46    }
47
48    /// Add a node to the static list
49    pub fn add_node(&self, node: ClusterNode) {
50        let mut nodes = self.nodes.write();
51        nodes.push(node);
52    }
53
54    /// Remove a node from the static list
55    pub fn remove_node(&self, node_id: &str) {
56        let mut nodes = self.nodes.write();
57        nodes.retain(|n| n.node_id != node_id);
58    }
59}
60
61#[async_trait]
62impl DiscoveryService for StaticDiscovery {
63    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
64        let nodes = self.nodes.read();
65        Ok(nodes.clone())
66    }
67
68    async fn register_node(&self, node: ClusterNode) -> Result<()> {
69        self.add_node(node);
70        Ok(())
71    }
72
73    async fn unregister_node(&self, node_id: &str) -> Result<()> {
74        self.remove_node(node_id);
75        Ok(())
76    }
77
78    async fn heartbeat(&self, node_id: &str) -> Result<()> {
79        let mut nodes = self.nodes.write();
80        if let Some(node) = nodes.iter_mut().find(|n| n.node_id == node_id) {
81            node.heartbeat();
82        }
83        Ok(())
84    }
85}
86
87/// Gossip-based discovery protocol
88pub struct GossipDiscovery {
89    /// Local node information
90    local_node: Arc<RwLock<ClusterNode>>,
91    /// Known nodes (node_id -> node)
92    nodes: Arc<DashMap<String, ClusterNode>>,
93    /// Seed nodes to bootstrap gossip
94    seed_nodes: Vec<SocketAddr>,
95    /// Gossip interval
96    gossip_interval: Duration,
97    /// Node timeout
98    node_timeout: Duration,
99}
100
101impl GossipDiscovery {
102    /// Create a new gossip discovery service
103    pub fn new(
104        local_node: ClusterNode,
105        seed_nodes: Vec<SocketAddr>,
106        gossip_interval: Duration,
107        node_timeout: Duration,
108    ) -> Self {
109        let nodes = Arc::new(DashMap::new());
110        nodes.insert(local_node.node_id.clone(), local_node.clone());
111
112        Self {
113            local_node: Arc::new(RwLock::new(local_node)),
114            nodes,
115            seed_nodes,
116            gossip_interval,
117            node_timeout,
118        }
119    }
120
121    /// Start the gossip protocol
122    pub async fn start(&self) -> Result<()> {
123        info!("Starting gossip discovery protocol");
124
125        // Bootstrap from seed nodes
126        self.bootstrap().await?;
127
128        // Start periodic gossip
129        let nodes = Arc::clone(&self.nodes);
130        let gossip_interval = self.gossip_interval;
131
132        tokio::spawn(async move {
133            let mut interval = time::interval(gossip_interval);
134            loop {
135                interval.tick().await;
136                Self::gossip_round(&nodes).await;
137            }
138        });
139
140        Ok(())
141    }
142
143    /// Bootstrap by contacting seed nodes
144    async fn bootstrap(&self) -> Result<()> {
145        debug!("Bootstrapping from {} seed nodes", self.seed_nodes.len());
146
147        for seed_addr in &self.seed_nodes {
148            // In a real implementation, this would contact the seed node
149            // For now, we'll simulate it
150            debug!("Contacting seed node at {}", seed_addr);
151        }
152
153        Ok(())
154    }
155
156    /// Perform a gossip round
157    async fn gossip_round(nodes: &Arc<DashMap<String, ClusterNode>>) {
158        // Select random subset of nodes to gossip with
159        let node_list: Vec<_> = nodes.iter().map(|e| e.value().clone()).collect();
160
161        if node_list.len() < 2 {
162            return;
163        }
164
165        debug!("Gossiping with {} nodes", node_list.len());
166
167        // In a real implementation, we would:
168        // 1. Select random peers
169        // 2. Exchange node lists
170        // 3. Merge received information
171        // 4. Detect failures
172    }
173
174    /// Merge gossip information from another node
175    pub fn merge_gossip(&self, remote_nodes: Vec<ClusterNode>) {
176        for node in remote_nodes {
177            if let Some(mut existing) = self.nodes.get_mut(&node.node_id) {
178                // Update if remote has newer information
179                if node.last_seen > existing.last_seen {
180                    *existing = node;
181                }
182            } else {
183                // Add new node
184                self.nodes.insert(node.node_id.clone(), node);
185            }
186        }
187    }
188
189    /// Remove failed nodes
190    pub fn prune_failed_nodes(&self) {
191        let now = Utc::now();
192        self.nodes.retain(|_, node| {
193            let elapsed = now
194                .signed_duration_since(node.last_seen)
195                .to_std()
196                .unwrap_or(Duration::MAX);
197            elapsed < self.node_timeout
198        });
199    }
200
201    /// Get gossip statistics
202    pub fn get_stats(&self) -> GossipStats {
203        let nodes: Vec<_> = self.nodes.iter().map(|e| e.value().clone()).collect();
204        let healthy = nodes
205            .iter()
206            .filter(|n| n.is_healthy(self.node_timeout))
207            .count();
208
209        GossipStats {
210            total_nodes: nodes.len(),
211            healthy_nodes: healthy,
212            seed_nodes: self.seed_nodes.len(),
213        }
214    }
215}
216
217#[async_trait]
218impl DiscoveryService for GossipDiscovery {
219    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
220        Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
221    }
222
223    async fn register_node(&self, node: ClusterNode) -> Result<()> {
224        self.nodes.insert(node.node_id.clone(), node);
225        Ok(())
226    }
227
228    async fn unregister_node(&self, node_id: &str) -> Result<()> {
229        self.nodes.remove(node_id);
230        Ok(())
231    }
232
233    async fn heartbeat(&self, node_id: &str) -> Result<()> {
234        if let Some(mut node) = self.nodes.get_mut(node_id) {
235            node.heartbeat();
236        }
237        Ok(())
238    }
239}
240
241/// Gossip protocol statistics
242#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct GossipStats {
244    pub total_nodes: usize,
245    pub healthy_nodes: usize,
246    pub seed_nodes: usize,
247}
248
249/// Multicast-based discovery (for local networks)
250pub struct MulticastDiscovery {
251    /// Local node
252    local_node: ClusterNode,
253    /// Discovered nodes
254    nodes: Arc<DashMap<String, ClusterNode>>,
255    /// Multicast address
256    multicast_addr: String,
257    /// Multicast port
258    multicast_port: u16,
259}
260
261impl MulticastDiscovery {
262    /// Create a new multicast discovery service
263    pub fn new(local_node: ClusterNode, multicast_addr: String, multicast_port: u16) -> Self {
264        Self {
265            local_node,
266            nodes: Arc::new(DashMap::new()),
267            multicast_addr,
268            multicast_port,
269        }
270    }
271
272    /// Start multicast discovery
273    pub async fn start(&self) -> Result<()> {
274        info!(
275            "Starting multicast discovery on {}:{}",
276            self.multicast_addr, self.multicast_port
277        );
278
279        // In a real implementation, this would:
280        // 1. Join multicast group
281        // 2. Send periodic announcements
282        // 3. Listen for other nodes
283        // 4. Update node list
284
285        Ok(())
286    }
287}
288
289#[async_trait]
290impl DiscoveryService for MulticastDiscovery {
291    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
292        Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
293    }
294
295    async fn register_node(&self, node: ClusterNode) -> Result<()> {
296        self.nodes.insert(node.node_id.clone(), node);
297        Ok(())
298    }
299
300    async fn unregister_node(&self, node_id: &str) -> Result<()> {
301        self.nodes.remove(node_id);
302        Ok(())
303    }
304
305    async fn heartbeat(&self, node_id: &str) -> Result<()> {
306        if let Some(mut node) = self.nodes.get_mut(node_id) {
307            node.heartbeat();
308        }
309        Ok(())
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback (127.0.0.1) socket address on `port`.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)
    }

    /// Build a node named `id` listening on loopback `port`.
    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        ClusterNode::new(id.to_string(), loopback(port))
    }

    /// Gossip discovery with a 5 s gossip interval and a 30 s node timeout.
    fn gossip_with_seeds(local: ClusterNode, seeds: Vec<SocketAddr>) -> GossipDiscovery {
        GossipDiscovery::new(local, seeds, Duration::from_secs(5), Duration::from_secs(30))
    }

    #[tokio::test]
    async fn test_static_discovery() {
        let discovery = StaticDiscovery::new(vec![
            create_test_node("node1", 8000),
            create_test_node("node2", 8001),
        ]);

        let found = discovery.discover_nodes().await.unwrap();
        assert_eq!(found.len(), 2);
    }

    #[tokio::test]
    async fn test_static_discovery_register() {
        let discovery = StaticDiscovery::new(Vec::new());
        discovery
            .register_node(create_test_node("node1", 8000))
            .await
            .unwrap();

        let found = discovery.discover_nodes().await.unwrap();
        assert_eq!(found.len(), 1);
    }

    #[tokio::test]
    async fn test_gossip_discovery() {
        let discovery = gossip_with_seeds(create_test_node("local", 8000), vec![loopback(9000)]);

        let found = discovery.discover_nodes().await.unwrap();
        // Before any gossip, only the local node is known.
        assert_eq!(found.len(), 1);
    }

    #[tokio::test]
    async fn test_gossip_merge() {
        let discovery = gossip_with_seeds(create_test_node("local", 8000), Vec::new());

        discovery.merge_gossip(vec![
            create_test_node("node1", 8001),
            create_test_node("node2", 8002),
        ]);

        // The local node plus the two merged remote nodes.
        assert_eq!(discovery.get_stats().total_nodes, 3);
    }
}