scirs2_core/distributed/
communication.rs1use crate::error::{CoreError, CoreResult, ErrorContext};
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::{mpsc, Arc, Mutex};
10
/// A message exchanged between nodes of the distributed runtime.
///
/// `CommunicationEndpoint` routes each variant to a registered handler by a
/// routing key: `"task_assignment"`, `"result"`, `"heartbeat"`, `"barrier"`,
/// or — for `Coordination` — the message's own `messagetype`.
///
/// Payload byte vectors are opaque to this module; their encoding is decided
/// by the sender and receiver.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributedMessage {
    /// Work assignment for the task identified by `taskid`.
    TaskAssignment { taskid: String, payload: Vec<u8> },
    /// Result bytes produced for the task identified by `taskid`.
    Result { taskid: String, result: Vec<u8> },
    /// Heartbeat from `nodeid`; the unit of `timestamp` is not fixed by this module.
    Heartbeat { nodeid: String, timestamp: u64 },
    /// Free-form coordination message; `messagetype` doubles as its routing key.
    Coordination { messagetype: String, data: Vec<u8> },
    /// Barrier message identified by `barrier_id`, carrying a participating-node count.
    Barrier {
        barrier_id: String,
        node_count: usize,
    },
}

29pub struct CommunicationEndpoint {
31 nodeid: String,
32 address: SocketAddr,
33 message_handlers: Arc<Mutex<HashMap<String, Box<dyn MessageHandler + Send + Sync>>>>,
34 sender: mpsc::Sender<DistributedMessage>,
35 receiver: Arc<Mutex<mpsc::Receiver<DistributedMessage>>>,
36}
37
38impl CommunicationEndpoint {
39 pub fn new(nodeid: String, address: SocketAddr) -> Self {
41 let (sender, receiver) = mpsc::channel();
42
43 Self {
44 nodeid,
45 address,
46 message_handlers: Arc::new(Mutex::new(HashMap::new())),
47 sender,
48 receiver: Arc::new(Mutex::new(receiver)),
49 }
50 }
51
52 pub fn send_message(&self, message: DistributedMessage) -> CoreResult<()> {
54 self.sender.send(message).map_err(|e| {
55 CoreError::CommunicationError(ErrorContext::new(format!("Failed to send message: {e}")))
56 })?;
57 Ok(())
58 }
59
60 pub fn register_handler<H>(&self, messagetype: String, handler: H) -> CoreResult<()>
62 where
63 H: MessageHandler + Send + Sync + 'static,
64 {
65 let mut handlers = self.message_handlers.lock().map_err(|_| {
66 CoreError::InvalidState(ErrorContext::new(
67 "Failed to acquire handlers lock".to_string(),
68 ))
69 })?;
70 handlers.insert(messagetype, Box::new(handler));
71 Ok(())
72 }
73
74 pub fn process_messages(&self) -> CoreResult<()> {
76 let receiver = self.receiver.lock().map_err(|_| {
77 CoreError::InvalidState(ErrorContext::new(
78 "Failed to acquire receiver lock".to_string(),
79 ))
80 })?;
81
82 while let Ok(message) = receiver.try_recv() {
83 self.handle_message(message)?;
84 }
85
86 Ok(())
87 }
88
89 fn handle_message(&self, message: DistributedMessage) -> CoreResult<()> {
90 let handlers = self.message_handlers.lock().map_err(|_| {
91 CoreError::InvalidState(ErrorContext::new(
92 "Failed to acquire handlers lock".to_string(),
93 ))
94 })?;
95
96 match &message {
97 DistributedMessage::TaskAssignment { .. } => {
98 if let Some(handler) = handlers.get("task_assignment") {
99 handler.handle(&message)?;
100 }
101 }
102 DistributedMessage::Result { .. } => {
103 if let Some(handler) = handlers.get("result") {
104 handler.handle(&message)?;
105 }
106 }
107 DistributedMessage::Heartbeat { .. } => {
108 if let Some(handler) = handlers.get("heartbeat") {
109 handler.handle(&message)?;
110 }
111 }
112 DistributedMessage::Coordination { messagetype, .. } => {
113 if let Some(handler) = handlers.get(messagetype) {
114 handler.handle(&message)?;
115 }
116 }
117 DistributedMessage::Barrier { .. } => {
118 if let Some(handler) = handlers.get("barrier") {
119 handler.handle(&message)?;
120 }
121 }
122 }
123
124 Ok(())
125 }
126
127 pub fn nodeid(&self) -> &str {
129 &self.nodeid
130 }
131
132 pub fn address(&self) -> SocketAddr {
134 self.address
135 }
136}
137
138pub trait MessageHandler {
140 fn handle(&self, message: &DistributedMessage) -> CoreResult<()>;
142}
143
144#[derive(Debug)]
146pub struct HeartbeatHandler {
147 #[allow(dead_code)]
148 nodeid: String,
149}
150
151impl HeartbeatHandler {
152 pub fn new(nodeid: String) -> Self {
154 Self { nodeid }
155 }
156}
157
158impl MessageHandler for HeartbeatHandler {
159 fn handle(&self, message: &DistributedMessage) -> CoreResult<()> {
160 if let DistributedMessage::Heartbeat { nodeid, timestamp } = message {
161 println!("Received heartbeat from {nodeid} at {timestamp}");
162 }
163 Ok(())
164 }
165}
166
167pub struct CommunicationManager {
169 endpoints: HashMap<String, CommunicationEndpoint>,
170 local_nodeid: String,
171}
172
173impl CommunicationManager {
174 pub fn new(local_nodeid: String) -> Self {
176 Self {
177 endpoints: HashMap::new(),
178 local_nodeid,
179 }
180 }
181
182 pub fn add_endpoint(&mut self, endpoint: CommunicationEndpoint) {
184 let nodeid = endpoint.nodeid().to_string();
185 self.endpoints.insert(nodeid, endpoint);
186 }
187
188 pub fn broadcast(&self, message: DistributedMessage) -> CoreResult<()> {
190 for endpoint in self.endpoints.values() {
191 endpoint.send_message(message.clone())?;
192 }
193 Ok(())
194 }
195
196 pub fn send_to(&self, nodeid: &str, message: DistributedMessage) -> CoreResult<()> {
198 if let Some(endpoint) = self.endpoints.get(nodeid) {
199 endpoint.send_message(message)?;
200 } else {
201 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
202 "Unknown node: {nodeid}"
203 ))));
204 }
205 Ok(())
206 }
207
208 pub fn process_all_messages(&self) -> CoreResult<()> {
210 for endpoint in self.endpoints.values() {
211 endpoint.process_messages()?;
212 }
213 Ok(())
214 }
215
216 pub fn local_nodeid(&self) -> &str {
218 &self.local_nodeid
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Loopback address shared by the tests; nothing is ever bound to it.
    fn test_address() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    /// Handler that counts how many messages were dispatched to it.
    struct CountingHandler {
        count: Arc<AtomicUsize>,
    }

    impl MessageHandler for CountingHandler {
        fn handle(&self, _message: &DistributedMessage) -> CoreResult<()> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[test]
    fn test_communication_endpoint_creation() {
        let address = test_address();
        let endpoint = CommunicationEndpoint::new("node1".to_string(), address);

        assert_eq!(endpoint.nodeid(), "node1");
        assert_eq!(endpoint.address(), address);
    }

    #[test]
    fn test_heartbeat_handler() {
        let handler = HeartbeatHandler::new("node1".to_string());
        let message = DistributedMessage::Heartbeat {
            nodeid: "node2".to_string(),
            timestamp: 123456789,
        };

        assert!(handler.handle(&message).is_ok());
    }

    #[test]
    fn test_communication_manager() {
        let mut manager = CommunicationManager::new("local_node".to_string());
        assert_eq!(manager.local_nodeid(), "local_node");

        let endpoint = CommunicationEndpoint::new("node1".to_string(), test_address());
        manager.add_endpoint(endpoint);

        assert!(manager.endpoints.contains_key("node1"));
    }

    #[test]
    fn test_endpoint_dispatches_by_routing_key() {
        let endpoint = CommunicationEndpoint::new("node1".to_string(), test_address());
        let heartbeats = Arc::new(AtomicUsize::new(0));
        let custom = Arc::new(AtomicUsize::new(0));

        assert!(endpoint
            .register_handler(
                "heartbeat".to_string(),
                CountingHandler { count: Arc::clone(&heartbeats) },
            )
            .is_ok());
        assert!(endpoint
            .register_handler("custom".to_string(), CountingHandler { count: Arc::clone(&custom) })
            .is_ok());

        assert!(endpoint
            .send_message(DistributedMessage::Heartbeat {
                nodeid: "node2".to_string(),
                timestamp: 1,
            })
            .is_ok());
        assert!(endpoint
            .send_message(DistributedMessage::Coordination {
                messagetype: "custom".to_string(),
                data: vec![1, 2, 3],
            })
            .is_ok());
        // No handler is registered for "task_assignment": the message is dropped without error.
        assert!(endpoint
            .send_message(DistributedMessage::TaskAssignment {
                taskid: "t1".to_string(),
                payload: Vec::new(),
            })
            .is_ok());

        assert!(endpoint.process_messages().is_ok());
        assert_eq!(heartbeats.load(Ordering::SeqCst), 1);
        assert_eq!(custom.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_send_to_unknown_node_fails() {
        let manager = CommunicationManager::new("local_node".to_string());
        let result = manager.send_to(
            "missing",
            DistributedMessage::Heartbeat {
                nodeid: "local_node".to_string(),
                timestamp: 0,
            },
        );

        assert!(matches!(result, Err(CoreError::InvalidArgument(_))));
    }
}