scirs2_core/distributed/
communication.rs

//! Distributed communication protocols
//!
//! This module provides communication protocols for distributed computing,
//! including message passing, synchronization, and coordination mechanisms.
5
6use crate::error::{CoreError, CoreResult, ErrorContext};
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::{mpsc, Arc, Mutex};
10
/// Message types for distributed communication.
///
/// A `CommunicationEndpoint` routes each variant to the `MessageHandler`
/// registered under a string key: `TaskAssignment`, `Result`, `Heartbeat` and
/// `Barrier` use the fixed keys `"task_assignment"`, `"result"`, `"heartbeat"`
/// and `"barrier"`, while `Coordination` is routed by its own `messagetype`.
///
/// `PartialEq`/`Eq` are derived so messages can be compared by value (for
/// example in tests or when checking for duplicates).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributedMessage {
    /// Task assignment message
    TaskAssignment {
        /// Identifier of the task being assigned.
        taskid: String,
        /// Raw payload bytes; not interpreted by this module.
        payload: Vec<u8>,
    },
    /// Result message
    Result {
        /// Identifier of the task the result refers to.
        taskid: String,
        /// Raw result bytes; not interpreted by this module.
        result: Vec<u8>,
    },
    /// Heartbeat message
    Heartbeat {
        /// Identifier of the node that sent the heartbeat.
        nodeid: String,
        /// Sender-supplied timestamp; units are not defined here (TODO confirm).
        timestamp: u64,
    },
    /// Coordination message
    Coordination {
        /// Handler key this message is dispatched under.
        messagetype: String,
        /// Raw payload bytes; not interpreted by this module.
        data: Vec<u8>,
    },
    /// Synchronization barrier
    Barrier {
        /// Identifier of the barrier.
        barrier_id: String,
        /// Number of participating nodes, presumably those expected at the
        /// barrier; not interpreted by this module.
        node_count: usize,
    },
}
28
29/// Communication endpoint for a node
30pub struct CommunicationEndpoint {
31    nodeid: String,
32    address: SocketAddr,
33    message_handlers: Arc<Mutex<HashMap<String, Box<dyn MessageHandler + Send + Sync>>>>,
34    sender: mpsc::Sender<DistributedMessage>,
35    receiver: Arc<Mutex<mpsc::Receiver<DistributedMessage>>>,
36}
37
38impl CommunicationEndpoint {
39    /// Create a new communication endpoint
40    pub fn new(nodeid: String, address: SocketAddr) -> Self {
41        let (sender, receiver) = mpsc::channel();
42
43        Self {
44            nodeid,
45            address,
46            message_handlers: Arc::new(Mutex::new(HashMap::new())),
47            sender,
48            receiver: Arc::new(Mutex::new(receiver)),
49        }
50    }
51
52    /// Send a message to another node
53    pub fn send_message(&self, message: DistributedMessage) -> CoreResult<()> {
54        self.sender.send(message).map_err(|e| {
55            CoreError::CommunicationError(ErrorContext::new(format!("Failed to send message: {e}")))
56        })?;
57        Ok(())
58    }
59
60    /// Register a message handler
61    pub fn register_handler<H>(&self, messagetype: String, handler: H) -> CoreResult<()>
62    where
63        H: MessageHandler + Send + Sync + 'static,
64    {
65        let mut handlers = self.message_handlers.lock().map_err(|_| {
66            CoreError::InvalidState(ErrorContext::new(
67                "Failed to acquire handlers lock".to_string(),
68            ))
69        })?;
70        handlers.insert(messagetype, Box::new(handler));
71        Ok(())
72    }
73
74    /// Process incoming messages
75    pub fn process_messages(&self) -> CoreResult<()> {
76        let receiver = self.receiver.lock().map_err(|_| {
77            CoreError::InvalidState(ErrorContext::new(
78                "Failed to acquire receiver lock".to_string(),
79            ))
80        })?;
81
82        while let Ok(message) = receiver.try_recv() {
83            self.handle_message(message)?;
84        }
85
86        Ok(())
87    }
88
89    fn handle_message(&self, message: DistributedMessage) -> CoreResult<()> {
90        let handlers = self.message_handlers.lock().map_err(|_| {
91            CoreError::InvalidState(ErrorContext::new(
92                "Failed to acquire handlers lock".to_string(),
93            ))
94        })?;
95
96        match &message {
97            DistributedMessage::TaskAssignment { .. } => {
98                if let Some(handler) = handlers.get("task_assignment") {
99                    handler.handle(&message)?;
100                }
101            }
102            DistributedMessage::Result { .. } => {
103                if let Some(handler) = handlers.get("result") {
104                    handler.handle(&message)?;
105                }
106            }
107            DistributedMessage::Heartbeat { .. } => {
108                if let Some(handler) = handlers.get("heartbeat") {
109                    handler.handle(&message)?;
110                }
111            }
112            DistributedMessage::Coordination { messagetype, .. } => {
113                if let Some(handler) = handlers.get(messagetype) {
114                    handler.handle(&message)?;
115                }
116            }
117            DistributedMessage::Barrier { .. } => {
118                if let Some(handler) = handlers.get("barrier") {
119                    handler.handle(&message)?;
120                }
121            }
122        }
123
124        Ok(())
125    }
126
127    /// Get node ID
128    pub fn nodeid(&self) -> &str {
129        &self.nodeid
130    }
131
132    /// Get address
133    pub fn address(&self) -> SocketAddr {
134        self.address
135    }
136}
137
138/// Trait for handling distributed messages
139pub trait MessageHandler {
140    /// Handle a received message
141    fn handle(&self, message: &DistributedMessage) -> CoreResult<()>;
142}
143
/// Default heartbeat handler.
///
/// Prints the sender and timestamp of each heartbeat it receives to standard
/// output and ignores every other message kind.
#[derive(Debug)]
pub struct HeartbeatHandler {
    /// Node id supplied to `new`.
    // Kept for identification/`Debug` output only: `handle` never reads it,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    nodeid: String,
}
150
151impl HeartbeatHandler {
152    /// Create a new heartbeat handler
153    pub fn new(nodeid: String) -> Self {
154        Self { nodeid }
155    }
156}
157
158impl MessageHandler for HeartbeatHandler {
159    fn handle(&self, message: &DistributedMessage) -> CoreResult<()> {
160        if let DistributedMessage::Heartbeat { nodeid, timestamp } = message {
161            println!("Received heartbeat from {nodeid} at {timestamp}");
162        }
163        Ok(())
164    }
165}
166
167/// Communication manager for coordinating distributed operations
168pub struct CommunicationManager {
169    endpoints: HashMap<String, CommunicationEndpoint>,
170    local_nodeid: String,
171}
172
173impl CommunicationManager {
174    /// Create a new communication manager
175    pub fn new(local_nodeid: String) -> Self {
176        Self {
177            endpoints: HashMap::new(),
178            local_nodeid,
179        }
180    }
181
182    /// Add a communication endpoint
183    pub fn add_endpoint(&mut self, endpoint: CommunicationEndpoint) {
184        let nodeid = endpoint.nodeid().to_string();
185        self.endpoints.insert(nodeid, endpoint);
186    }
187
188    /// Broadcast a message to all nodes
189    pub fn broadcast(&self, message: DistributedMessage) -> CoreResult<()> {
190        for endpoint in self.endpoints.values() {
191            endpoint.send_message(message.clone())?;
192        }
193        Ok(())
194    }
195
196    /// Send a message to a specific node
197    pub fn send_to(&self, nodeid: &str, message: DistributedMessage) -> CoreResult<()> {
198        if let Some(endpoint) = self.endpoints.get(nodeid) {
199            endpoint.send_message(message)?;
200        } else {
201            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
202                "Unknown node: {nodeid}"
203            ))));
204        }
205        Ok(())
206    }
207
208    /// Process all pending messages
209    pub fn process_all_messages(&self) -> CoreResult<()> {
210        for endpoint in self.endpoints.values() {
211            endpoint.process_messages()?;
212        }
213        Ok(())
214    }
215
216    /// Get local node ID
217    pub fn local_nodeid(&self) -> &str {
218        &self.local_nodeid
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Loopback address on `port`; the endpoints under test never bind it.
    fn localhost(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port)
    }

    /// Heartbeat message from `from` with a fixed timestamp.
    fn heartbeat(from: &str) -> DistributedMessage {
        DistributedMessage::Heartbeat {
            nodeid: from.to_string(),
            timestamp: 123456789,
        }
    }

    /// Handler that only counts how many messages it has been given.
    struct CountingHandler {
        count: Arc<AtomicUsize>,
    }

    impl MessageHandler for CountingHandler {
        fn handle(&self, _message: &DistributedMessage) -> CoreResult<()> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a `CountingHandler` that increments the shared `count`.
    fn counting_handler(count: &Arc<AtomicUsize>) -> CountingHandler {
        CountingHandler {
            count: Arc::clone(count),
        }
    }

    #[test]
    fn test_communication_endpoint_creation() {
        let address = localhost(8080);
        let endpoint = CommunicationEndpoint::new("node1".to_string(), address);

        assert_eq!(endpoint.nodeid(), "node1");
        assert_eq!(endpoint.address(), address);
    }

    #[test]
    fn test_heartbeat_handler() {
        let handler = HeartbeatHandler::new("node1".to_string());
        assert!(handler.handle(&heartbeat("node2")).is_ok());
    }

    #[test]
    fn test_communication_manager() {
        let mut manager = CommunicationManager::new("local_node".to_string());
        assert_eq!(manager.local_nodeid(), "local_node");

        let endpoint = CommunicationEndpoint::new("node1".to_string(), localhost(8080));
        manager.add_endpoint(endpoint);

        assert!(manager.endpoints.contains_key("node1"));
    }

    #[test]
    fn test_messages_dispatched_by_kind() {
        let endpoint = CommunicationEndpoint::new("node1".to_string(), localhost(8080));
        let heartbeats = Arc::new(AtomicUsize::new(0));
        let syncs = Arc::new(AtomicUsize::new(0));
        assert!(endpoint
            .register_handler("heartbeat".to_string(), counting_handler(&heartbeats))
            .is_ok());
        assert!(endpoint
            .register_handler("sync".to_string(), counting_handler(&syncs))
            .is_ok());

        assert!(endpoint.send_message(heartbeat("node2")).is_ok());
        assert!(endpoint
            .send_message(DistributedMessage::Coordination {
                messagetype: "sync".to_string(),
                data: vec![1, 2, 3],
            })
            .is_ok());
        // No "result" handler is registered, so this message must be ignored.
        assert!(endpoint
            .send_message(DistributedMessage::Result {
                taskid: "task1".to_string(),
                result: Vec::new(),
            })
            .is_ok());

        assert!(endpoint.process_messages().is_ok());
        assert_eq!(heartbeats.load(Ordering::SeqCst), 1);
        assert_eq!(syncs.load(Ordering::SeqCst), 1);

        // The queue was drained, so a second pass dispatches nothing new.
        assert!(endpoint.process_messages().is_ok());
        assert_eq!(heartbeats.load(Ordering::SeqCst), 1);
        assert_eq!(syncs.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_broadcast_reaches_every_endpoint() {
        let mut manager = CommunicationManager::new("local_node".to_string());
        let received = Arc::new(AtomicUsize::new(0));
        for &(nodeid, port) in &[("node1", 8081u16), ("node2", 8082u16)] {
            let endpoint = CommunicationEndpoint::new(nodeid.to_string(), localhost(port));
            assert!(endpoint
                .register_handler("heartbeat".to_string(), counting_handler(&received))
                .is_ok());
            manager.add_endpoint(endpoint);
        }

        assert!(manager.broadcast(heartbeat("local_node")).is_ok());
        assert!(manager.process_all_messages().is_ok());
        assert_eq!(received.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_send_to_known_node_delivers() {
        let mut manager = CommunicationManager::new("local_node".to_string());
        let received = Arc::new(AtomicUsize::new(0));
        let endpoint = CommunicationEndpoint::new("node1".to_string(), localhost(8080));
        assert!(endpoint
            .register_handler("heartbeat".to_string(), counting_handler(&received))
            .is_ok());
        manager.add_endpoint(endpoint);

        assert!(manager.send_to("node1", heartbeat("local_node")).is_ok());
        assert!(manager.process_all_messages().is_ok());
        assert_eq!(received.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_send_to_unknown_node_fails() {
        let manager = CommunicationManager::new("local_node".to_string());
        let outcome = manager.send_to("missing", heartbeat("local_node"));
        assert!(matches!(outcome, Err(CoreError::InvalidArgument(_))));
    }
}