1use crate::common::IntegrateFloat;
8use crate::distributed::types::{
9 AckStatus, BoundaryConditions, BoundaryData, ChunkId, ChunkResult, DistributedError,
10 DistributedMessage, DistributedResult, JobId, NodeCapabilities, NodeId, NodeStatus, WorkChunk,
11};
12use scirs2_core::ndarray::Array1;
13use std::collections::{HashMap, VecDeque};
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::{Arc, Condvar, Mutex, RwLock};
16use std::time::{Duration, Instant};
17
/// Point-to-point message transport for a single node.
///
/// Holds FIFO queues of outgoing and incoming messages together with
/// acknowledgment bookkeeping so callers can detect unacknowledged sends.
pub struct MessageChannel<F: IntegrateFloat> {
    /// Messages queued for transmission, paired with their destination node.
    outbox: Mutex<VecDeque<(NodeId, DistributedMessage<F>)>>,
    /// Messages delivered to this node and not yet consumed.
    inbox: Mutex<VecDeque<(NodeId, DistributedMessage<F>)>>,
    /// Monotonically increasing source of message IDs (starts at 1).
    next_message_id: AtomicU64,
    /// Message ID -> (send time, target) for sends awaiting acknowledgment.
    pending_acks: Mutex<HashMap<u64, (Instant, NodeId)>>,
    /// How long a send may stay unacknowledged before `check_timeouts` reports it.
    ack_timeout: Duration,
    /// Condition variable used to block `receive` until a message arrives.
    inbox_cv: Condvar,
    /// Mutex paired with `inbox_cv` for condvar waits (separate from `inbox` itself).
    inbox_mutex: Mutex<()>,
}
35
36impl<F: IntegrateFloat> MessageChannel<F> {
37 pub fn new(ack_timeout: Duration) -> Self {
39 Self {
40 outbox: Mutex::new(VecDeque::new()),
41 inbox: Mutex::new(VecDeque::new()),
42 next_message_id: AtomicU64::new(1),
43 pending_acks: Mutex::new(HashMap::new()),
44 ack_timeout,
45 inbox_cv: Condvar::new(),
46 inbox_mutex: Mutex::new(()),
47 }
48 }
49
50 pub fn generate_message_id(&self) -> u64 {
52 self.next_message_id.fetch_add(1, Ordering::SeqCst)
53 }
54
55 pub fn send(&self, target: NodeId, message: DistributedMessage<F>) -> DistributedResult<u64> {
57 let message_id = self.generate_message_id();
58
59 match self.outbox.lock() {
60 Ok(mut outbox) => {
61 outbox.push_back((target, message));
62 }
63 Err(_) => {
64 return Err(DistributedError::CommunicationError(
65 "Failed to acquire outbox lock".to_string(),
66 ))
67 }
68 }
69
70 match self.pending_acks.lock() {
72 Ok(mut pending) => {
73 pending.insert(message_id, (Instant::now(), target));
74 }
75 Err(_) => {
76 return Err(DistributedError::CommunicationError(
77 "Failed to track acknowledgment".to_string(),
78 ))
79 }
80 }
81
82 Ok(message_id)
83 }
84
85 pub fn receive(&self, timeout: Duration) -> Option<(NodeId, DistributedMessage<F>)> {
87 let deadline = Instant::now() + timeout;
88
89 loop {
90 if let Ok(mut inbox) = self.inbox.lock() {
92 if let Some(msg) = inbox.pop_front() {
93 return Some(msg);
94 }
95 }
96
97 let remaining = deadline.saturating_duration_since(Instant::now());
99 if remaining.is_zero() {
100 return None;
101 }
102
103 if let Ok(guard) = self.inbox_mutex.lock() {
104 let _ = self.inbox_cv.wait_timeout(guard, remaining);
105 }
106 }
107 }
108
109 pub fn try_receive(&self) -> Option<(NodeId, DistributedMessage<F>)> {
111 match self.inbox.lock() {
112 Ok(mut inbox) => inbox.pop_front(),
113 Err(_) => None,
114 }
115 }
116
117 pub fn deliver(&self, source: NodeId, message: DistributedMessage<F>) -> DistributedResult<()> {
119 match self.inbox.lock() {
120 Ok(mut inbox) => {
121 inbox.push_back((source, message));
122 self.inbox_cv.notify_one();
124 Ok(())
125 }
126 Err(_) => Err(DistributedError::CommunicationError(
127 "Failed to acquire inbox lock".to_string(),
128 )),
129 }
130 }
131
132 pub fn process_ack(&self, message_id: u64, status: AckStatus) -> DistributedResult<()> {
134 match self.pending_acks.lock() {
135 Ok(mut pending) => {
136 if pending.remove(&message_id).is_some() {
137 if status == AckStatus::Error {
138 return Err(DistributedError::CommunicationError(
139 "Message processing failed at remote node".to_string(),
140 ));
141 }
142 Ok(())
143 } else {
144 Ok(())
146 }
147 }
148 Err(_) => Err(DistributedError::CommunicationError(
149 "Failed to process acknowledgment".to_string(),
150 )),
151 }
152 }
153
154 pub fn check_timeouts(&self) -> Vec<(u64, NodeId)> {
156 match self.pending_acks.lock() {
157 Ok(mut pending) => {
158 let now = Instant::now();
159 let timed_out: Vec<_> = pending
160 .iter()
161 .filter(|(_, (sent_at, _))| now.duration_since(*sent_at) > self.ack_timeout)
162 .map(|(id, (_, node))| (*id, *node))
163 .collect();
164
165 for (id, _) in &timed_out {
166 pending.remove(id);
167 }
168
169 timed_out
170 }
171 Err(_) => Vec::new(),
172 }
173 }
174
175 pub fn outbox_size(&self) -> usize {
177 self.outbox.lock().map(|o| o.len()).unwrap_or(0)
178 }
179
180 pub fn inbox_size(&self) -> usize {
182 self.inbox.lock().map(|i| i.len()).unwrap_or(0)
183 }
184
185 pub fn drain_outbox(&self) -> Vec<(NodeId, DistributedMessage<F>)> {
187 match self.outbox.lock() {
188 Ok(mut outbox) => outbox.drain(..).collect(),
189 Err(_) => Vec::new(),
190 }
191 }
192}
193
/// Collects boundary data exchanged between neighboring domain chunks.
///
/// Stores received boundary slices keyed by `(requesting chunk, providing
/// chunk)` and tracks outstanding requests so stale ones can be timed out.
pub struct BoundaryExchanger<F: IntegrateFloat> {
    /// Boundary data received so far, keyed by (target chunk, source chunk).
    received_boundaries: RwLock<HashMap<(ChunkId, ChunkId), BoundaryData<F>>>,
    /// Outstanding requests and when they were issued, keyed the same way.
    pending_requests: Mutex<HashMap<(ChunkId, ChunkId), Instant>>,
    /// Age past which a pending request is reported by `check_timeouts`.
    timeout: Duration,
}
203
204impl<F: IntegrateFloat> BoundaryExchanger<F> {
205 pub fn new(timeout: Duration) -> Self {
207 Self {
208 received_boundaries: RwLock::new(HashMap::new()),
209 pending_requests: Mutex::new(HashMap::new()),
210 timeout,
211 }
212 }
213
214 pub fn request_boundary(
216 &self,
217 target_chunk: ChunkId,
218 source_chunk: ChunkId,
219 ) -> DistributedResult<()> {
220 match self.pending_requests.lock() {
221 Ok(mut pending) => {
222 pending.insert((target_chunk, source_chunk), Instant::now());
223 Ok(())
224 }
225 Err(_) => Err(DistributedError::CommunicationError(
226 "Failed to register boundary request".to_string(),
227 )),
228 }
229 }
230
231 pub fn receive_boundary(
233 &self,
234 target_chunk: ChunkId,
235 source_chunk: ChunkId,
236 data: BoundaryData<F>,
237 ) -> DistributedResult<()> {
238 match self.received_boundaries.write() {
239 Ok(mut boundaries) => {
240 boundaries.insert((target_chunk, source_chunk), data);
241
242 if let Ok(mut pending) = self.pending_requests.lock() {
244 pending.remove(&(target_chunk, source_chunk));
245 }
246
247 Ok(())
248 }
249 Err(_) => Err(DistributedError::CommunicationError(
250 "Failed to store boundary data".to_string(),
251 )),
252 }
253 }
254
255 pub fn get_boundary(
257 &self,
258 target_chunk: ChunkId,
259 source_chunk: ChunkId,
260 ) -> Option<BoundaryData<F>> {
261 match self.received_boundaries.read() {
262 Ok(boundaries) => boundaries.get(&(target_chunk, source_chunk)).cloned(),
263 Err(_) => None,
264 }
265 }
266
267 pub fn build_boundary_conditions(
269 &self,
270 chunk_id: ChunkId,
271 left_neighbor: Option<ChunkId>,
272 right_neighbor: Option<ChunkId>,
273 ) -> BoundaryConditions<F> {
274 let mut bc = BoundaryConditions::default();
275
276 if let Some(left_id) = left_neighbor {
277 bc.left_boundary = self.get_boundary(chunk_id, left_id);
278 }
279
280 if let Some(right_id) = right_neighbor {
281 bc.right_boundary = self.get_boundary(chunk_id, right_id);
282 }
283
284 bc
285 }
286
287 pub fn check_timeouts(&self) -> Vec<(ChunkId, ChunkId)> {
289 match self.pending_requests.lock() {
290 Ok(mut pending) => {
291 let now = Instant::now();
292 let timed_out: Vec<_> = pending
293 .iter()
294 .filter(|(_, sent_at)| now.duration_since(**sent_at) > self.timeout)
295 .map(|(key, _)| *key)
296 .collect();
297
298 for key in &timed_out {
299 pending.remove(key);
300 }
301
302 timed_out
303 }
304 Err(_) => Vec::new(),
305 }
306 }
307
308 pub fn clear(&self) {
310 if let Ok(mut boundaries) = self.received_boundaries.write() {
311 boundaries.clear();
312 }
313 if let Ok(mut pending) = self.pending_requests.lock() {
314 pending.clear();
315 }
316 }
317}
318
/// Reusable synchronization barrier for a fixed number of participants.
///
/// Each "generation" of the barrier has its own ID; `new_barrier` starts the
/// next generation and invalidates the previous one.
pub struct SyncBarrier {
    /// Source of generation IDs (the next ID to hand out).
    barrier_id: AtomicU64,
    /// Number of distinct arrivals required to release a generation.
    expected_count: usize,
    /// Mutable barrier state, guarded for use with `cv`.
    state: Mutex<BarrierState>,
    /// Condition variable used to wake threads blocked in `wait`.
    cv: Condvar,
}
330
/// Interior state of a `SyncBarrier` generation (guarded by the barrier's mutex).
struct BarrierState {
    /// ID of the generation currently in progress.
    current_id: u64,
    /// Nodes that have arrived at the current generation (deduplicated).
    arrived: Vec<NodeId>,
    /// True once enough nodes have arrived and waiters may proceed.
    released: bool,
}
340
341impl SyncBarrier {
342 pub fn new(expected_count: usize) -> Self {
344 Self {
345 barrier_id: AtomicU64::new(1),
346 expected_count,
347 state: Mutex::new(BarrierState {
348 current_id: 1,
349 arrived: Vec::new(),
350 released: false,
351 }),
352 cv: Condvar::new(),
353 }
354 }
355
356 pub fn new_barrier(&self) -> u64 {
358 let new_id = self.barrier_id.fetch_add(1, Ordering::SeqCst);
359
360 if let Ok(mut state) = self.state.lock() {
361 state.current_id = new_id;
362 state.arrived.clear();
363 state.released = false;
364 }
365
366 new_id
367 }
368
369 pub fn arrive(&self, barrier_id: u64, node_id: NodeId) -> DistributedResult<()> {
371 let mut state = self
372 .state
373 .lock()
374 .map_err(|_| DistributedError::SyncError("Failed to acquire barrier lock".into()))?;
375
376 if state.current_id != barrier_id {
377 return Err(DistributedError::SyncError(format!(
378 "Barrier ID mismatch: expected {}, got {}",
379 state.current_id, barrier_id
380 )));
381 }
382
383 if !state.arrived.contains(&node_id) {
384 state.arrived.push(node_id);
385 }
386
387 if state.arrived.len() >= self.expected_count {
389 state.released = true;
390 self.cv.notify_all();
391 }
392
393 Ok(())
394 }
395
396 pub fn wait(&self, barrier_id: u64, timeout: Duration) -> DistributedResult<()> {
398 let deadline = Instant::now() + timeout;
399
400 let mut state = self
401 .state
402 .lock()
403 .map_err(|_| DistributedError::SyncError("Failed to acquire barrier lock".into()))?;
404
405 while !state.released && state.current_id == barrier_id {
406 let remaining = deadline.saturating_duration_since(Instant::now());
407 if remaining.is_zero() {
408 return Err(DistributedError::SyncError(
409 "Barrier wait timeout".to_string(),
410 ));
411 }
412
413 let (new_state, result) = self.cv.wait_timeout(state, remaining).map_err(|_| {
414 DistributedError::SyncError("Failed to wait on barrier".to_string())
415 })?;
416
417 state = new_state;
418
419 if result.timed_out() && !state.released {
420 return Err(DistributedError::SyncError(
421 "Barrier wait timeout".to_string(),
422 ));
423 }
424 }
425
426 Ok(())
427 }
428
429 pub fn is_complete(&self, barrier_id: u64) -> bool {
431 match self.state.lock() {
432 Ok(state) => state.current_id == barrier_id && state.released,
433 Err(_) => false,
434 }
435 }
436
437 pub fn arrived_count(&self) -> usize {
439 match self.state.lock() {
440 Ok(state) => state.arrived.len(),
441 Err(_) => 0,
442 }
443 }
444}
445
/// High-level communication facade tying together the message channel,
/// boundary exchange, and barrier synchronization for one node.
pub struct Communicator<F: IntegrateFloat> {
    /// ID of the node this communicator represents.
    local_node_id: NodeId,
    /// Transport used for all outgoing and incoming messages.
    channel: Arc<MessageChannel<F>>,
    /// Storage and bookkeeping for inter-chunk boundary data.
    boundary_exchanger: Arc<BoundaryExchanger<F>>,
    /// Registered synchronization barriers, keyed by barrier ID.
    barriers: RwLock<HashMap<u64, Arc<SyncBarrier>>>,
    /// Known peer nodes (targets for `broadcast`).
    peers: RwLock<Vec<NodeId>>,
}
459
460impl<F: IntegrateFloat> Communicator<F> {
461 pub fn new(
463 local_node_id: NodeId,
464 channel: Arc<MessageChannel<F>>,
465 boundary_exchanger: Arc<BoundaryExchanger<F>>,
466 ) -> Self {
467 Self {
468 local_node_id,
469 channel,
470 boundary_exchanger,
471 barriers: RwLock::new(HashMap::new()),
472 peers: RwLock::new(Vec::new()),
473 }
474 }
475
476 pub fn local_id(&self) -> NodeId {
478 self.local_node_id
479 }
480
481 pub fn add_peer(&self, node_id: NodeId) -> DistributedResult<()> {
483 match self.peers.write() {
484 Ok(mut peers) => {
485 if !peers.contains(&node_id) {
486 peers.push(node_id);
487 }
488 Ok(())
489 }
490 Err(_) => Err(DistributedError::CommunicationError(
491 "Failed to add peer".to_string(),
492 )),
493 }
494 }
495
496 pub fn remove_peer(&self, node_id: NodeId) -> DistributedResult<()> {
498 match self.peers.write() {
499 Ok(mut peers) => {
500 peers.retain(|&id| id != node_id);
501 Ok(())
502 }
503 Err(_) => Err(DistributedError::CommunicationError(
504 "Failed to remove peer".to_string(),
505 )),
506 }
507 }
508
509 pub fn get_peers(&self) -> Vec<NodeId> {
511 match self.peers.read() {
512 Ok(peers) => peers.clone(),
513 Err(_) => Vec::new(),
514 }
515 }
516
517 pub fn send_work(
519 &self,
520 target: NodeId,
521 chunk: WorkChunk<F>,
522 deadline: Option<Duration>,
523 ) -> DistributedResult<u64> {
524 let message = DistributedMessage::WorkAssignment { chunk, deadline };
525 self.channel.send(target, message)
526 }
527
528 pub fn send_result(&self, target: NodeId, result: ChunkResult<F>) -> DistributedResult<u64> {
530 let message = DistributedMessage::WorkResult { result };
531 self.channel.send(target, message)
532 }
533
534 pub fn send_boundary(
536 &self,
537 target: NodeId,
538 source_chunk: ChunkId,
539 target_chunk: ChunkId,
540 boundary_data: BoundaryData<F>,
541 ) -> DistributedResult<u64> {
542 let message = DistributedMessage::BoundaryExchange {
543 source_chunk,
544 target_chunk,
545 boundary_data,
546 };
547 self.channel.send(target, message)
548 }
549
550 pub fn broadcast(&self, message: DistributedMessage<F>) -> DistributedResult<Vec<u64>> {
552 let peers = self.get_peers();
553 let mut message_ids = Vec::with_capacity(peers.len());
554
555 for peer in peers {
556 let id = self.channel.send(peer, message.clone())?;
557 message_ids.push(id);
558 }
559
560 Ok(message_ids)
561 }
562
563 pub fn create_barrier(&self, expected_count: usize) -> DistributedResult<u64> {
565 let barrier = Arc::new(SyncBarrier::new(expected_count));
566 let barrier_id = barrier.new_barrier();
567
568 match self.barriers.write() {
569 Ok(mut barriers) => {
570 barriers.insert(barrier_id, barrier);
571 Ok(barrier_id)
572 }
573 Err(_) => Err(DistributedError::SyncError(
574 "Failed to create barrier".to_string(),
575 )),
576 }
577 }
578
579 pub fn barrier(&self, barrier_id: u64, timeout: Duration) -> DistributedResult<()> {
581 let barrier = {
583 match self.barriers.read() {
584 Ok(barriers) => barriers.get(&barrier_id).cloned(),
585 Err(_) => None,
586 }
587 };
588
589 let barrier = barrier.ok_or_else(|| {
590 DistributedError::SyncError(format!("Barrier {} not found", barrier_id))
591 })?;
592
593 barrier.arrive(barrier_id, self.local_node_id)?;
595
596 let message = DistributedMessage::SyncBarrier {
598 barrier_id,
599 node_id: self.local_node_id,
600 };
601 let _ = self.broadcast(message);
602
603 barrier.wait(barrier_id, timeout)
605 }
606
607 pub fn process_barrier_message(
609 &self,
610 barrier_id: u64,
611 node_id: NodeId,
612 ) -> DistributedResult<()> {
613 match self.barriers.read() {
614 Ok(barriers) => {
615 if let Some(barrier) = barriers.get(&barrier_id) {
616 barrier.arrive(barrier_id, node_id)?;
617 }
618 Ok(())
619 }
620 Err(_) => Err(DistributedError::SyncError(
621 "Failed to process barrier message".to_string(),
622 )),
623 }
624 }
625
626 pub fn receive_boundary(
628 &self,
629 target_chunk: ChunkId,
630 source_chunk: ChunkId,
631 data: BoundaryData<F>,
632 ) -> DistributedResult<()> {
633 self.boundary_exchanger
634 .receive_boundary(target_chunk, source_chunk, data)
635 }
636
637 pub fn get_boundary_conditions(
639 &self,
640 chunk_id: ChunkId,
641 left_neighbor: Option<ChunkId>,
642 right_neighbor: Option<ChunkId>,
643 ) -> BoundaryConditions<F> {
644 self.boundary_exchanger
645 .build_boundary_conditions(chunk_id, left_neighbor, right_neighbor)
646 }
647}
648
649pub fn serialize_boundary_data<F: IntegrateFloat>(data: &BoundaryData<F>) -> Vec<u8> {
651 let mut bytes = Vec::new();
653
654 let time_f64 = data.time.to_f64().unwrap_or(0.0);
656 bytes.extend_from_slice(&time_f64.to_le_bytes());
657
658 let state_len = data.state.len() as u64;
660 bytes.extend_from_slice(&state_len.to_le_bytes());
661 for val in data.state.iter() {
662 let val_f64 = val.to_f64().unwrap_or(0.0);
663 bytes.extend_from_slice(&val_f64.to_le_bytes());
664 }
665
666 bytes.extend_from_slice(&data.source_chunk.0.to_le_bytes());
668
669 bytes
670}
671
672pub fn deserialize_boundary_data<F: IntegrateFloat>(
674 bytes: &[u8],
675) -> DistributedResult<BoundaryData<F>> {
676 if bytes.len() < 16 {
677 return Err(DistributedError::CommunicationError(
678 "Insufficient data for boundary deserialization".to_string(),
679 ));
680 }
681
682 let mut offset = 0;
683
684 let time_bytes: [u8; 8] = bytes[offset..offset + 8]
686 .try_into()
687 .map_err(|_| DistributedError::CommunicationError("Invalid time bytes".to_string()))?;
688 let time_f64 = f64::from_le_bytes(time_bytes);
689 let time = F::from(time_f64).ok_or_else(|| {
690 DistributedError::CommunicationError("Failed to convert time".to_string())
691 })?;
692 offset += 8;
693
694 let len_bytes: [u8; 8] = bytes[offset..offset + 8]
696 .try_into()
697 .map_err(|_| DistributedError::CommunicationError("Invalid length bytes".to_string()))?;
698 let state_len = u64::from_le_bytes(len_bytes) as usize;
699 offset += 8;
700
701 if bytes.len() < offset + state_len * 8 + 8 {
703 return Err(DistributedError::CommunicationError(
704 "Insufficient data for state values".to_string(),
705 ));
706 }
707
708 let mut state = Array1::zeros(state_len);
709 for i in 0..state_len {
710 let val_bytes: [u8; 8] = bytes[offset..offset + 8]
711 .try_into()
712 .map_err(|_| DistributedError::CommunicationError("Invalid value bytes".to_string()))?;
713 let val_f64 = f64::from_le_bytes(val_bytes);
714 state[i] = F::from(val_f64).ok_or_else(|| {
715 DistributedError::CommunicationError("Failed to convert value".to_string())
716 })?;
717 offset += 8;
718 }
719
720 let chunk_bytes: [u8; 8] = bytes[offset..offset + 8]
722 .try_into()
723 .map_err(|_| DistributedError::CommunicationError("Invalid chunk ID bytes".to_string()))?;
724 let source_chunk = ChunkId(u64::from_le_bytes(chunk_bytes));
725
726 Ok(BoundaryData {
727 time,
728 state,
729 derivative: None,
730 source_chunk,
731 })
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_channel() {
        let channel: MessageChannel<f64> = MessageChannel::new(Duration::from_secs(5));
        let peer = NodeId::new(1);
        let heartbeat = DistributedMessage::Heartbeat {
            node_id: peer,
            status: NodeStatus::Available,
            timestamp: 12345,
        };

        // Sending queues the message and yields an ID for ack tracking.
        assert!(channel.send(peer, heartbeat.clone()).is_ok());

        // A delivered message must be retrievable without blocking.
        channel.deliver(peer, heartbeat).expect("Delivery failed");
        assert!(channel.try_receive().is_some());
    }

    #[test]
    fn test_boundary_exchanger() {
        let exchanger: BoundaryExchanger<f64> = BoundaryExchanger::new(Duration::from_secs(5));
        let target = ChunkId::new(1);
        let source = ChunkId::new(0);

        let boundary = BoundaryData {
            time: 1.0,
            state: Array1::from_vec(vec![1.0, 2.0, 3.0]),
            derivative: None,
            source_chunk: source,
        };

        exchanger
            .receive_boundary(target, source, boundary)
            .expect("Failed to receive boundary");

        // Stored data must be retrievable and carry the original timestamp.
        let stored = exchanger.get_boundary(target, source);
        assert!(stored.is_some());
        assert_eq!(stored.map(|b| b.time), Some(1.0));
    }

    #[test]
    fn test_sync_barrier() {
        let barrier = SyncBarrier::new(2);
        let id = barrier.new_barrier();

        // One arrival out of two: not yet complete.
        barrier.arrive(id, NodeId::new(1)).expect("Failed to arrive");
        assert!(!barrier.is_complete(id));

        // Second arrival releases the barrier.
        barrier.arrive(id, NodeId::new(2)).expect("Failed to arrive");
        assert!(barrier.is_complete(id));
    }

    #[test]
    fn test_boundary_serialization() {
        let original = BoundaryData {
            time: 1.5,
            state: Array1::from_vec(vec![1.0, 2.0, 3.0]),
            derivative: None,
            source_chunk: ChunkId::new(42),
        };

        // Round-trip through the wire format preserves time, length, and chunk ID.
        let encoded = serialize_boundary_data(&original);
        let decoded: BoundaryData<f64> =
            deserialize_boundary_data(&encoded).expect("Deserialization failed");

        assert!((decoded.time - original.time).abs() < 1e-10);
        assert_eq!(decoded.state.len(), original.state.len());
        assert_eq!(decoded.source_chunk.0, original.source_chunk.0);
    }
}