1use crate::{ClusterError, ClusterNode, NodeStatus, Result};
6use async_trait::async_trait;
7use chrono::Utc;
8use dashmap::DashMap;
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::time;
16use tracing::{debug, info, warn};
17
18#[async_trait]
20pub trait DiscoveryService: Send + Sync {
21 async fn discover_nodes(&self) -> Result<Vec<ClusterNode>>;
23
24 async fn register_node(&self, node: ClusterNode) -> Result<()>;
26
27 async fn unregister_node(&self, node_id: &str) -> Result<()>;
29
30 async fn heartbeat(&self, node_id: &str) -> Result<()>;
32}
33
34pub struct StaticDiscovery {
36 nodes: Arc<RwLock<Vec<ClusterNode>>>,
38}
39
40impl StaticDiscovery {
41 pub fn new(nodes: Vec<ClusterNode>) -> Self {
43 Self {
44 nodes: Arc::new(RwLock::new(nodes)),
45 }
46 }
47
48 pub fn add_node(&self, node: ClusterNode) {
50 let mut nodes = self.nodes.write();
51 nodes.push(node);
52 }
53
54 pub fn remove_node(&self, node_id: &str) {
56 let mut nodes = self.nodes.write();
57 nodes.retain(|n| n.node_id != node_id);
58 }
59}
60
61#[async_trait]
62impl DiscoveryService for StaticDiscovery {
63 async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
64 let nodes = self.nodes.read();
65 Ok(nodes.clone())
66 }
67
68 async fn register_node(&self, node: ClusterNode) -> Result<()> {
69 self.add_node(node);
70 Ok(())
71 }
72
73 async fn unregister_node(&self, node_id: &str) -> Result<()> {
74 self.remove_node(node_id);
75 Ok(())
76 }
77
78 async fn heartbeat(&self, node_id: &str) -> Result<()> {
79 let mut nodes = self.nodes.write();
80 if let Some(node) = nodes.iter_mut().find(|n| n.node_id == node_id) {
81 node.heartbeat();
82 }
83 Ok(())
84 }
85}
86
87pub struct GossipDiscovery {
89 local_node: Arc<RwLock<ClusterNode>>,
91 nodes: Arc<DashMap<String, ClusterNode>>,
93 seed_nodes: Vec<SocketAddr>,
95 gossip_interval: Duration,
97 node_timeout: Duration,
99}
100
101impl GossipDiscovery {
102 pub fn new(
104 local_node: ClusterNode,
105 seed_nodes: Vec<SocketAddr>,
106 gossip_interval: Duration,
107 node_timeout: Duration,
108 ) -> Self {
109 let nodes = Arc::new(DashMap::new());
110 nodes.insert(local_node.node_id.clone(), local_node.clone());
111
112 Self {
113 local_node: Arc::new(RwLock::new(local_node)),
114 nodes,
115 seed_nodes,
116 gossip_interval,
117 node_timeout,
118 }
119 }
120
121 pub async fn start(&self) -> Result<()> {
123 info!("Starting gossip discovery protocol");
124
125 self.bootstrap().await?;
127
128 let nodes = Arc::clone(&self.nodes);
130 let gossip_interval = self.gossip_interval;
131
132 tokio::spawn(async move {
133 let mut interval = time::interval(gossip_interval);
134 loop {
135 interval.tick().await;
136 Self::gossip_round(&nodes).await;
137 }
138 });
139
140 Ok(())
141 }
142
143 async fn bootstrap(&self) -> Result<()> {
145 debug!("Bootstrapping from {} seed nodes", self.seed_nodes.len());
146
147 for seed_addr in &self.seed_nodes {
148 debug!("Contacting seed node at {}", seed_addr);
151 }
152
153 Ok(())
154 }
155
156 async fn gossip_round(nodes: &Arc<DashMap<String, ClusterNode>>) {
158 let node_list: Vec<_> = nodes.iter().map(|e| e.value().clone()).collect();
160
161 if node_list.len() < 2 {
162 return;
163 }
164
165 debug!("Gossiping with {} nodes", node_list.len());
166
167 }
173
174 pub fn merge_gossip(&self, remote_nodes: Vec<ClusterNode>) {
176 for node in remote_nodes {
177 if let Some(mut existing) = self.nodes.get_mut(&node.node_id) {
178 if node.last_seen > existing.last_seen {
180 *existing = node;
181 }
182 } else {
183 self.nodes.insert(node.node_id.clone(), node);
185 }
186 }
187 }
188
189 pub fn prune_failed_nodes(&self) {
191 let now = Utc::now();
192 self.nodes.retain(|_, node| {
193 let elapsed = now
194 .signed_duration_since(node.last_seen)
195 .to_std()
196 .unwrap_or(Duration::MAX);
197 elapsed < self.node_timeout
198 });
199 }
200
201 pub fn get_stats(&self) -> GossipStats {
203 let nodes: Vec<_> = self.nodes.iter().map(|e| e.value().clone()).collect();
204 let healthy = nodes
205 .iter()
206 .filter(|n| n.is_healthy(self.node_timeout))
207 .count();
208
209 GossipStats {
210 total_nodes: nodes.len(),
211 healthy_nodes: healthy,
212 seed_nodes: self.seed_nodes.len(),
213 }
214 }
215}
216
217#[async_trait]
218impl DiscoveryService for GossipDiscovery {
219 async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
220 Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
221 }
222
223 async fn register_node(&self, node: ClusterNode) -> Result<()> {
224 self.nodes.insert(node.node_id.clone(), node);
225 Ok(())
226 }
227
228 async fn unregister_node(&self, node_id: &str) -> Result<()> {
229 self.nodes.remove(node_id);
230 Ok(())
231 }
232
233 async fn heartbeat(&self, node_id: &str) -> Result<()> {
234 if let Some(mut node) = self.nodes.get_mut(node_id) {
235 node.heartbeat();
236 }
237 Ok(())
238 }
239}
240
241#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct GossipStats {
244 pub total_nodes: usize,
245 pub healthy_nodes: usize,
246 pub seed_nodes: usize,
247}
248
249pub struct MulticastDiscovery {
251 local_node: ClusterNode,
253 nodes: Arc<DashMap<String, ClusterNode>>,
255 multicast_addr: String,
257 multicast_port: u16,
259}
260
261impl MulticastDiscovery {
262 pub fn new(local_node: ClusterNode, multicast_addr: String, multicast_port: u16) -> Self {
264 Self {
265 local_node,
266 nodes: Arc::new(DashMap::new()),
267 multicast_addr,
268 multicast_port,
269 }
270 }
271
272 pub async fn start(&self) -> Result<()> {
274 info!(
275 "Starting multicast discovery on {}:{}",
276 self.multicast_addr, self.multicast_port
277 );
278
279 Ok(())
286 }
287}
288
289#[async_trait]
290impl DiscoveryService for MulticastDiscovery {
291 async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
292 Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
293 }
294
295 async fn register_node(&self, node: ClusterNode) -> Result<()> {
296 self.nodes.insert(node.node_id.clone(), node);
297 Ok(())
298 }
299
300 async fn unregister_node(&self, node_id: &str) -> Result<()> {
301 self.nodes.remove(node_id);
302 Ok(())
303 }
304
305 async fn heartbeat(&self, node_id: &str) -> Result<()> {
306 if let Some(mut node) = self.nodes.get_mut(node_id) {
307 node.heartbeat();
308 }
309 Ok(())
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds a node with the given id listening on `127.0.0.1:port`.
    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, port));
        ClusterNode::new(id.to_string(), addr)
    }

    /// Builds a gossip discovery with a 5 s gossip interval and a 30 s node timeout.
    fn gossip_with_seeds(seeds: Vec<SocketAddr>) -> GossipDiscovery {
        GossipDiscovery::new(
            create_test_node("local", 8000),
            seeds,
            Duration::from_secs(5),
            Duration::from_secs(30),
        )
    }

    #[tokio::test]
    async fn test_static_discovery() {
        let initial = vec![
            create_test_node("node1", 8000),
            create_test_node("node2", 8001),
        ];
        let discovery = StaticDiscovery::new(initial);

        let found = discovery.discover_nodes().await.unwrap();
        assert_eq!(found.len(), 2);
    }

    #[tokio::test]
    async fn test_static_discovery_register() {
        let discovery = StaticDiscovery::new(Vec::new());

        discovery
            .register_node(create_test_node("node1", 8000))
            .await
            .unwrap();

        let found = discovery.discover_nodes().await.unwrap();
        assert_eq!(found.len(), 1);
    }

    #[tokio::test]
    async fn test_gossip_discovery() {
        let seed_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 9000));
        let discovery = gossip_with_seeds(vec![seed_addr]);

        // Only the local node is known before any gossip has been merged.
        let found = discovery.discover_nodes().await.unwrap();
        assert_eq!(found.len(), 1);
    }

    #[tokio::test]
    async fn test_gossip_merge() {
        let discovery = gossip_with_seeds(Vec::new());

        discovery.merge_gossip(vec![
            create_test_node("node1", 8001),
            create_test_node("node2", 8002),
        ]);

        // The local node plus the two merged remote nodes.
        let stats = discovery.get_stats();
        assert_eq!(stats.total_nodes, 3);
    }
}