scirs2_cluster/distributed/message_passing.rs

//! Message passing system for distributed clustering coordination
//!
//! This module provides the messaging infrastructure for coordinating
//! distributed clustering operations across multiple worker nodes.
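//!
//! # Example
//!
//! A minimal sketch of the coordinator lifecycle (worker threads and result
//! handling elided; the module path is assumed from this file's location):
//!
//! ```ignore
//! use scirs2_cluster::distributed::message_passing::{
//!     ClusteringMessage, MessagePassingConfig, MessagePassingCoordinator, MessagePriority,
//! };
//!
//! let mut coordinator = MessagePassingCoordinator::<f64>::new(0, MessagePassingConfig::default());
//! let _worker_rx = coordinator.register_worker(1);
//! coordinator
//!     .send_message_to_worker(
//!         1,
//!         ClusteringMessage::ComputeLocal { round: 0, max_iterations: 100 },
//!         MessagePriority::Normal,
//!     )
//!     .unwrap();
//! coordinator.shutdown();
//! ```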

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};

use crate::error::{ClusteringError, Result};

/// Message types for distributed clustering coordination
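///
/// A worker typically matches on the variant it receives from its channel; a
/// minimal sketch (the `envelope` binding is assumed to come from the
/// worker's receiver):
///
/// ```ignore
/// match envelope.message {
///     ClusteringMessage::ComputeLocal { round, max_iterations } => {
///         // run the local clustering step for this round
///     }
///     ClusteringMessage::Terminate => return,
///     _ => { /* other variants handled as needed */ }
/// }
/// ```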
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringMessage<F: Float> {
    /// Initialize worker with partition data
    InitializeWorker {
        workerid: usize,
        partition_data: Array2<F>,
        initial_centroids: Array2<F>,
    },
    /// Update global centroids
    UpdateCentroids { round: usize, centroids: Array2<F> },
    /// Request local computation
    ComputeLocal { round: usize, max_iterations: usize },
    /// Local computation result
    LocalResult {
        workerid: usize,
        round: usize,
        local_centroids: Array2<F>,
        local_labels: Array1<usize>,
        local_inertia: f64,
        computation_time_ms: u64,
    },
    /// Heartbeat for health monitoring
    Heartbeat {
        workerid: usize,
        timestamp: u64,
        cpu_usage: f64,
        memory_usage: f64,
    },
    /// Synchronization barrier
    SyncBarrier {
        round: usize,
        participant_count: usize,
    },
    /// Convergence check result
    ConvergenceCheck {
        round: usize,
        converged: bool,
        max_centroid_movement: f64,
    },
    /// Terminate worker
    Terminate,
    /// Checkpoint creation request
    CreateCheckpoint { round: usize },
    /// Checkpoint data
    CheckpointData {
        workerid: usize,
        round: usize,
        centroids: Array2<F>,
        labels: Array1<usize>,
    },
    /// Recovery request
    RecoveryRequest {
        failed_workerid: usize,
        recovery_strategy: RecoveryStrategy,
    },
    /// Load balancing request
    LoadBalance {
        target_worker_loads: HashMap<usize, f64>,
    },
    /// Data migration for load balancing
    MigrateData {
        source_worker: usize,
        target_worker: usize,
        data_subset: Array2<F>,
    },
    /// Acknowledgment message
    Acknowledgment { workerid: usize, message_id: u64 },
}

/// Recovery strategies for failed workers
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RecoveryStrategy {
    /// Redistribute failed worker's data to other workers
    Redistribute,
    /// Replace failed worker with a new one
    Replace,
    /// Restore from checkpoint
    Checkpoint,
    /// Restart entire computation
    Restart,
    /// Continue with degraded performance
    Degrade,
}

/// Message priority levels
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MessagePriority {
    Critical = 0, // Immediate processing required
    High = 1,     // High priority
    Normal = 2,   // Normal processing
    Low = 3,      // Background processing
}

/// Message envelope with metadata
#[derive(Debug, Clone)]
pub struct MessageEnvelope<F: Float> {
    pub message_id: u64,
    pub sender_id: usize,
    pub receiver_id: usize,
    pub priority: MessagePriority,
    pub timestamp: u64,
    pub retry_count: u32,
    pub timeout_ms: u64,
    pub message: ClusteringMessage<F>,
}

/// Message passing coordinator for distributed clustering
#[derive(Debug)]
pub struct MessagePassingCoordinator<F: Float> {
    pub coordinator_id: usize,
    pub worker_channels: HashMap<usize, Sender<MessageEnvelope<F>>>,
    pub coordinator_receiver: Receiver<MessageEnvelope<F>>,
    pub coordinator_sender: Sender<MessageEnvelope<F>>,
    pub message_counter: Arc<Mutex<u64>>,
    pub pending_messages: HashMap<u64, MessageEnvelope<F>>,
    pub message_timeouts: HashMap<u64, Instant>,
    pub worker_status: HashMap<usize, WorkerStatus>,
    pub sync_barriers: HashMap<usize, SynchronizationBarrier>,
    pub config: MessagePassingConfig,
}

/// Worker status for health monitoring
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WorkerStatus {
    Active,
    Inactive,
    Failed,
    Recovering,
}

/// Configuration for message passing system
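///
/// Individual fields can be overridden with struct update syntax; a sketch
/// (module path assumed from this file's location):
///
/// ```ignore
/// use scirs2_cluster::distributed::message_passing::MessagePassingConfig;
///
/// let config = MessagePassingConfig {
///     message_timeout_ms: 10_000,
///     max_retry_attempts: 5,
///     ..Default::default()
/// };
/// ```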
#[derive(Debug, Clone)]
pub struct MessagePassingConfig {
    pub max_message_queue_size: usize,
    pub message_timeout_ms: u64,
    pub max_retry_attempts: u32,
    pub heartbeat_interval_ms: u64,
    pub sync_timeout_ms: u64,
    pub enable_message_compression: bool,
    pub enable_message_ordering: bool,
    pub batch_size: usize,
}

impl Default for MessagePassingConfig {
    fn default() -> Self {
        Self {
            max_message_queue_size: 1000,
            message_timeout_ms: 30000,
            max_retry_attempts: 3,
            heartbeat_interval_ms: 5000,
            sync_timeout_ms: 60000,
            enable_message_compression: false,
            enable_message_ordering: true,
            batch_size: 10,
        }
    }
}

/// Synchronization barrier for coordinating worker phases
#[derive(Debug)]
pub struct SynchronizationBarrier {
    pub round: usize,
    pub expected_participants: usize,
    pub arrived_participants: HashSet<usize>,
    pub barrier_start_time: Instant,
    pub timeout_ms: u64,
}

impl<F: Float + Debug + Send + Sync + 'static> MessagePassingCoordinator<F> {
    /// Create new message passing coordinator
    pub fn new(coordinatorid: usize, config: MessagePassingConfig) -> Self {
        let (coordinator_sender, coordinator_receiver) = mpsc::channel();

        Self {
            coordinator_id: coordinatorid,
            worker_channels: HashMap::new(),
            coordinator_receiver,
            coordinator_sender,
            message_counter: Arc::new(Mutex::new(0)),
            pending_messages: HashMap::new(),
            message_timeouts: HashMap::new(),
            worker_status: HashMap::new(),
            sync_barriers: HashMap::new(),
            config,
        }
    }

    /// Register a new worker with the coordinator
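    ///
    /// Returns the receiving end of the worker's channel, which the caller
    /// typically moves into the worker's thread. A minimal sketch (message
    /// handling elided):
    ///
    /// ```ignore
    /// let rx = coordinator.register_worker(1);
    /// std::thread::spawn(move || {
    ///     while let Ok(envelope) = rx.recv() {
    ///         if matches!(envelope.message, ClusteringMessage::Terminate) {
    ///             break;
    ///         }
    ///         // handle the other variants here
    ///     }
    /// });
    /// ```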
    pub fn register_worker(&mut self, workerid: usize) -> Receiver<MessageEnvelope<F>> {
        let (sender, receiver) = mpsc::channel();
        self.worker_channels.insert(workerid, sender);
        self.worker_status.insert(workerid, WorkerStatus::Active);
        receiver
    }

    /// Send message to a specific worker
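    ///
    /// Returns the assigned message id; the envelope is tracked in
    /// `pending_messages` until it times out, after which it is retried or
    /// the worker is marked failed. A sketch:
    ///
    /// ```ignore
    /// let id = coordinator.send_message_to_worker(
    ///     1,
    ///     ClusteringMessage::ComputeLocal { round: 0, max_iterations: 100 },
    ///     MessagePriority::Normal,
    /// )?;
    /// ```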
    pub fn send_message_to_worker(
        &mut self,
        workerid: usize,
        message: ClusteringMessage<F>,
        priority: MessagePriority,
    ) -> Result<u64> {
        let message_id = {
            let mut counter = self.message_counter.lock().unwrap();
            *counter += 1;
            *counter
        };

        let envelope = MessageEnvelope {
            message_id,
            sender_id: self.coordinator_id,
            receiver_id: workerid,
            priority,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64,
            retry_count: 0,
            timeout_ms: self.config.message_timeout_ms,
            message,
        };

        if let Some(sender) = self.worker_channels.get(&workerid) {
            sender.send(envelope.clone()).map_err(|_| {
                ClusteringError::InvalidInput(format!("Worker {} unavailable", workerid))
            })?;

            self.pending_messages.insert(message_id, envelope);
            self.message_timeouts.insert(message_id, Instant::now());
            Ok(message_id)
        } else {
            Err(ClusteringError::InvalidInput(format!(
                "Worker {} not registered",
                workerid
            )))
        }
    }

    /// Broadcast message to all workers
    pub fn broadcast_message(
        &mut self,
        message: ClusteringMessage<F>,
        priority: MessagePriority,
    ) -> Result<Vec<u64>> {
        let workerids: Vec<usize> = self.worker_channels.keys().copied().collect();
        let mut message_ids = Vec::new();

        for workerid in workerids {
            let message_id = self.send_message_to_worker(workerid, message.clone(), priority)?;
            message_ids.push(message_id);
        }

        Ok(message_ids)
    }

    /// Process incoming messages from workers
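    ///
    /// Drains whatever is currently queued without blocking; `timeout` only
    /// bounds how long a continuous stream of arrivals is drained. A typical
    /// driver loop (sketch):
    ///
    /// ```ignore
    /// for envelope in coordinator.process_messages(Duration::from_millis(100))? {
    ///     match envelope.message {
    ///         ClusteringMessage::LocalResult { workerid, local_inertia, .. } => {
    ///             // aggregate the worker's partial result
    ///         }
    ///         ClusteringMessage::Heartbeat { workerid, .. } => {
    ///             // refresh the worker's liveness state
    ///         }
    ///         _ => {}
    ///     }
    /// }
    /// ```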
    pub fn process_messages(&mut self, timeout: Duration) -> Result<Vec<MessageEnvelope<F>>> {
        let mut messages = Vec::new();
        let deadline = Instant::now() + timeout;

        while Instant::now() < deadline {
            match self.coordinator_receiver.try_recv() {
                Ok(envelope) => {
                    messages.push(envelope);
                }
                Err(std::sync::mpsc::TryRecvError::Empty) => {
                    // No more messages available
                    break;
                }
                Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                    return Err(ClusteringError::InvalidInput(
                        "Coordinator channel disconnected".to_string(),
                    ));
                }
            }
        }

        // Clean up timed-out messages
        self.cleanup_timed_out_messages();

        Ok(messages)
    }

    /// Create synchronization barrier
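    ///
    /// A barrier is created once per round, populated via
    /// `register_barrier_arrival`, and polled via `wait_for_barrier`; a
    /// sketch:
    ///
    /// ```ignore
    /// coordinator.create_sync_barrier(round, num_workers)?;
    /// // ... as workers report in:
    /// coordinator.register_barrier_arrival(round, workerid)?;
    /// // poll until complete; returns Err once a timed-out barrier is removed
    /// while !coordinator.wait_for_barrier(round)? {
    ///     std::thread::sleep(Duration::from_millis(10));
    /// }
    /// ```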
    pub fn create_sync_barrier(
        &mut self,
        round: usize,
        expected_participants: usize,
    ) -> Result<()> {
        let barrier = SynchronizationBarrier {
            round,
            expected_participants,
            arrived_participants: HashSet::new(),
            barrier_start_time: Instant::now(),
            timeout_ms: self.config.sync_timeout_ms,
        };

        self.sync_barriers.insert(round, barrier);
        Ok(())
    }

    /// Check (without blocking) whether all workers have reached the
    /// synchronization barrier; on timeout the barrier is removed and
    /// `Ok(false)` is returned
    pub fn wait_for_barrier(&mut self, round: usize) -> Result<bool> {
        if let Some(barrier) = self.sync_barriers.get_mut(&round) {
            let timeout_reached =
                barrier.barrier_start_time.elapsed().as_millis() as u64 > barrier.timeout_ms;

            if timeout_reached {
                // Remove timed-out barrier
                self.sync_barriers.remove(&round);
                return Ok(false);
            }

            let all_arrived = barrier.arrived_participants.len() >= barrier.expected_participants;
            if all_arrived {
                self.sync_barriers.remove(&round);
                Ok(true)
            } else {
                Ok(false)
            }
        } else {
            Err(ClusteringError::InvalidInput(format!(
                "Sync barrier for round {} not found",
                round
            )))
        }
    }

    /// Register worker arrival at synchronization barrier
    pub fn register_barrier_arrival(&mut self, round: usize, workerid: usize) -> Result<()> {
        if let Some(barrier) = self.sync_barriers.get_mut(&round) {
            barrier.arrived_participants.insert(workerid);
            Ok(())
        } else {
            Err(ClusteringError::InvalidInput(format!(
                "Sync barrier for round {} not found",
                round
            )))
        }
    }

    /// Clean up timed-out messages and retry failed sends
    fn cleanup_timed_out_messages(&mut self) {
        let now = Instant::now();
        let timeout_duration = Duration::from_millis(self.config.message_timeout_ms);

        let mut timed_out_messages = Vec::new();

        for (&message_id, &send_time) in &self.message_timeouts {
            if now.duration_since(send_time) > timeout_duration {
                timed_out_messages.push(message_id);
            }
        }

        for message_id in timed_out_messages {
            if let Some(envelope) = self.pending_messages.remove(&message_id) {
                self.message_timeouts.remove(&message_id);

                // Retry if under retry limit
                if envelope.retry_count < self.config.max_retry_attempts {
                    let mut retry_envelope = envelope;
                    retry_envelope.retry_count += 1;

                    if let Some(sender) = self.worker_channels.get(&retry_envelope.receiver_id) {
                        if sender.send(retry_envelope.clone()).is_ok() {
                            // Re-track the retried message so it can time out
                            // again and eventually exhaust its retry budget.
                            self.pending_messages.insert(message_id, retry_envelope);
                            self.message_timeouts.insert(message_id, Instant::now());
                        }
                    }
                } else {
                    // Mark worker as failed after max retries
                    self.worker_status
                        .insert(envelope.receiver_id, WorkerStatus::Failed);
                }
            }
        }
    }

    /// Get worker status
    pub fn get_worker_status(&self, workerid: usize) -> Option<WorkerStatus> {
        self.worker_status.get(&workerid).copied()
    }

    /// Update worker status
    pub fn update_worker_status(&mut self, workerid: usize, status: WorkerStatus) {
        self.worker_status.insert(workerid, status);
    }

    /// Get active workers
    pub fn get_active_workers(&self) -> Vec<usize> {
        self.worker_status
            .iter()
            .filter(|(_, &status)| status == WorkerStatus::Active)
            .map(|(&id, _)| id)
            .collect()
    }

    /// Get failed workers
    pub fn get_failed_workers(&self) -> Vec<usize> {
        self.worker_status
            .iter()
            .filter(|(_, &status)| status == WorkerStatus::Failed)
            .map(|(&id, _)| id)
            .collect()
    }

    /// Shutdown coordinator and all worker channels
    pub fn shutdown(&mut self) {
        // Send terminate message to all workers
        let _ = self.broadcast_message(ClusteringMessage::Terminate, MessagePriority::Critical);

        // Clear all state
        self.worker_channels.clear();
        self.pending_messages.clear();
        self.message_timeouts.clear();
        self.worker_status.clear();
        self.sync_barriers.clear();
    }
}

impl SynchronizationBarrier {
    /// Check if barrier is complete
    pub fn is_complete(&self) -> bool {
        self.arrived_participants.len() >= self.expected_participants
    }

    /// Check if barrier has timed out
    pub fn is_timed_out(&self) -> bool {
        self.barrier_start_time.elapsed().as_millis() as u64 > self.timeout_ms
    }

    /// Get barrier completion as a fraction in [0.0, 1.0]
    pub fn completion_percentage(&self) -> f64 {
        self.arrived_participants.len() as f64 / self.expected_participants as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_message_passing_coordinator_creation() {
        let config = MessagePassingConfig::default();
        let coordinator = MessagePassingCoordinator::<f64>::new(0, config);

        assert_eq!(coordinator.coordinator_id, 0);
        assert!(coordinator.worker_channels.is_empty());
        assert!(coordinator.pending_messages.is_empty());
    }

    #[test]
    fn test_worker_registration() {
        let config = MessagePassingConfig::default();
        let mut coordinator = MessagePassingCoordinator::<f64>::new(0, config);

        let _receiver = coordinator.register_worker(1);
        assert!(coordinator.worker_channels.contains_key(&1));
        assert_eq!(coordinator.get_worker_status(1), Some(WorkerStatus::Active));
    }

    #[test]
    fn test_sync_barrier_creation() {
        let config = MessagePassingConfig::default();
        let mut coordinator = MessagePassingCoordinator::<f64>::new(0, config);

        let result = coordinator.create_sync_barrier(1, 3);
        assert!(result.is_ok());
        assert!(coordinator.sync_barriers.contains_key(&1));
    }

    #[test]
    fn test_sync_barrier_completion() {
        let mut barrier = SynchronizationBarrier {
            round: 1,
            expected_participants: 2,
            arrived_participants: HashSet::new(),
            barrier_start_time: Instant::now(),
            timeout_ms: 1000,
        };

        assert!(!barrier.is_complete());
        assert_relative_eq!(barrier.completion_percentage(), 0.0);

        barrier.arrived_participants.insert(1);
        assert_relative_eq!(barrier.completion_percentage(), 0.5);

        barrier.arrived_participants.insert(2);
        assert!(barrier.is_complete());
        assert_relative_eq!(barrier.completion_percentage(), 1.0);
    }
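
    // Round-trip sketch added for illustration: a message sent to a
    // registered worker arrives on that worker's receiver with the envelope
    // metadata filled in.
    #[test]
    fn test_send_message_round_trip() {
        let config = MessagePassingConfig::default();
        let mut coordinator = MessagePassingCoordinator::<f64>::new(0, config);

        let receiver = coordinator.register_worker(1);
        let message_id = coordinator
            .send_message_to_worker(1, ClusteringMessage::Terminate, MessagePriority::Critical)
            .unwrap();

        let envelope = receiver.recv().unwrap();
        assert_eq!(envelope.message_id, message_id);
        assert_eq!(envelope.receiver_id, 1);
        assert!(matches!(envelope.message, ClusteringMessage::Terminate));
    }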
}