1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::Float;
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::fmt::Debug;
10use std::sync::mpsc::{self, Receiver, Sender};
11use std::sync::{Arc, Mutex};
12use std::time::{Duration, Instant};
13
14use serde::{Deserialize, Serialize};
15
16use crate::error::{ClusteringError, Result};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum ClusteringMessage<F: Float> {
21 InitializeWorker {
23 workerid: usize,
24 partition_data: Array2<F>,
25 initial_centroids: Array2<F>,
26 },
27 UpdateCentroids { round: usize, centroids: Array2<F> },
29 ComputeLocal { round: usize, max_iterations: usize },
31 LocalResult {
33 workerid: usize,
34 round: usize,
35 local_centroids: Array2<F>,
36 local_labels: Array1<usize>,
37 local_inertia: f64,
38 computation_time_ms: u64,
39 },
40 Heartbeat {
42 workerid: usize,
43 timestamp: u64,
44 cpu_usage: f64,
45 memory_usage: f64,
46 },
47 SyncBarrier {
49 round: usize,
50 participant_count: usize,
51 },
52 ConvergenceCheck {
54 round: usize,
55 converged: bool,
56 max_centroid_movement: f64,
57 },
58 Terminate,
60 CreateCheckpoint { round: usize },
62 CheckpointData {
64 workerid: usize,
65 round: usize,
66 centroids: Array2<F>,
67 labels: Array1<usize>,
68 },
69 RecoveryRequest {
71 failed_workerid: usize,
72 recovery_strategy: RecoveryStrategy,
73 },
74 LoadBalance {
76 target_worker_loads: HashMap<usize, f64>,
77 },
78 MigrateData {
80 source_worker: usize,
81 target_worker: usize,
82 data_subset: Array2<F>,
83 },
84 Acknowledgment { workerid: usize, message_id: u64 },
86}
87
88#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
90pub enum RecoveryStrategy {
91 Redistribute,
93 Replace,
95 Checkpoint,
97 Restart,
99 Degrade,
101}
102
/// Urgency attached to every `MessageEnvelope`.
///
/// Discriminants grow as urgency drops, so the derived `Ord` sorts
/// `Critical` before `High`, `Normal` and `Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MessagePriority {
    /// Highest urgency (discriminant 0).
    Critical = 0,
    /// Discriminant 1.
    High = 1,
    /// Discriminant 2.
    Normal = 2,
    /// Lowest urgency (discriminant 3).
    Low = 3,
}
111
112#[derive(Debug, Clone)]
114pub struct MessageEnvelope<F: Float> {
115 pub message_id: u64,
116 pub sender_id: usize,
117 pub receiver_id: usize,
118 pub priority: MessagePriority,
119 pub timestamp: u64,
120 pub retry_count: u32,
121 pub timeout_ms: u64,
122 pub message: ClusteringMessage<F>,
123}
124
125#[derive(Debug)]
127pub struct MessagePassingCoordinator<F: Float> {
128 pub coordinator_id: usize,
129 pub worker_channels: HashMap<usize, Sender<MessageEnvelope<F>>>,
130 pub coordinator_receiver: Receiver<MessageEnvelope<F>>,
131 pub coordinator_sender: Sender<MessageEnvelope<F>>,
132 pub message_counter: Arc<Mutex<u64>>,
133 pub pending_messages: HashMap<u64, MessageEnvelope<F>>,
134 pub message_timeouts: HashMap<u64, Instant>,
135 pub worker_status: HashMap<usize, WorkerStatus>,
136 pub sync_barriers: HashMap<usize, SynchronizationBarrier>,
137 pub config: MessagePassingConfig,
138}
139
/// Lifecycle state the coordinator records for each worker.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the status can be
/// used as a map/set key and in `Eq`-bounded generic code; equality on a
/// fieldless enum is always total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkerStatus {
    /// Assigned when a worker is registered with the coordinator.
    Active,
    /// Not assigned anywhere in this file; available to callers.
    Inactive,
    /// Assigned by the coordinator when a message to the worker exhausted
    /// its retries.
    Failed,
    /// Not assigned anywhere in this file; available to callers.
    Recovering,
}
148
/// Tunables for `MessagePassingCoordinator`.
///
/// Only `message_timeout_ms`, `max_retry_attempts` and `sync_timeout_ms` are
/// read by the coordinator code in this file; the remaining fields are
/// carried for other consumers.
#[derive(Debug, Clone)]
pub struct MessagePassingConfig {
    /// Upper bound on queued messages (not enforced in this file).
    pub max_message_queue_size: usize,
    /// Milliseconds after which an unconfirmed message counts as timed out.
    pub message_timeout_ms: u64,
    /// Number of re-sends allowed before the receiving worker is marked failed.
    pub max_retry_attempts: u32,
    /// Expected heartbeat period in milliseconds (not used in this file).
    pub heartbeat_interval_ms: u64,
    /// Milliseconds a synchronization barrier may stay open.
    pub sync_timeout_ms: u64,
    /// Compression toggle (not used in this file).
    pub enable_message_compression: bool,
    /// Ordering toggle (not used in this file).
    pub enable_message_ordering: bool,
    /// Batch size (not used in this file).
    pub batch_size: usize,
}
161
162impl Default for MessagePassingConfig {
163 fn default() -> Self {
164 Self {
165 max_message_queue_size: 1000,
166 message_timeout_ms: 30000,
167 max_retry_attempts: 3,
168 heartbeat_interval_ms: 5000,
169 sync_timeout_ms: 60000,
170 enable_message_compression: false,
171 enable_message_ordering: true,
172 batch_size: 10,
173 }
174 }
175}
176
/// Bookkeeping for one synchronization round.
#[derive(Debug)]
pub struct SynchronizationBarrier {
    /// Round this barrier belongs to.
    pub round: usize,
    /// Number of distinct participants that must arrive.
    pub expected_participants: usize,
    /// Ids of the participants that have arrived so far (duplicates collapse).
    pub arrived_participants: HashSet<usize>,
    /// Moment the barrier was opened; timeouts are measured from here.
    pub barrier_start_time: Instant,
    /// Milliseconds the barrier may stay open.
    pub timeout_ms: u64,
}
186
187impl<F: Float + Debug + Send + Sync + 'static> MessagePassingCoordinator<F> {
188 pub fn new(coordinatorid: usize, config: MessagePassingConfig) -> Self {
190 let (coordinator_sender, coordinator_receiver) = mpsc::channel();
191
192 Self {
193 coordinator_id: coordinatorid,
194 worker_channels: HashMap::new(),
195 coordinator_receiver,
196 coordinator_sender,
197 message_counter: Arc::new(Mutex::new(0)),
198 pending_messages: HashMap::new(),
199 message_timeouts: HashMap::new(),
200 worker_status: HashMap::new(),
201 sync_barriers: HashMap::new(),
202 config,
203 }
204 }
205
206 pub fn register_worker(&mut self, workerid: usize) -> Receiver<MessageEnvelope<F>> {
208 let (sender, receiver) = mpsc::channel();
209 self.worker_channels.insert(workerid, sender);
210 self.worker_status.insert(workerid, WorkerStatus::Active);
211 receiver
212 }
213
214 pub fn send_message_to_worker(
216 &mut self,
217 workerid: usize,
218 message: ClusteringMessage<F>,
219 priority: MessagePriority,
220 ) -> Result<u64> {
221 let message_id = {
222 let mut counter = self.message_counter.lock().unwrap();
223 *counter += 1;
224 *counter
225 };
226
227 let envelope = MessageEnvelope {
228 message_id,
229 sender_id: self.coordinator_id,
230 receiver_id: workerid,
231 priority,
232 timestamp: std::time::SystemTime::now()
233 .duration_since(std::time::UNIX_EPOCH)
234 .unwrap_or_default()
235 .as_millis() as u64,
236 retry_count: 0,
237 timeout_ms: self.config.message_timeout_ms,
238 message,
239 };
240
241 if let Some(sender) = self.worker_channels.get(&workerid) {
242 sender.send(envelope.clone()).map_err(|_| {
243 ClusteringError::InvalidInput(format!("Worker {} unavailable", workerid))
244 })?;
245
246 self.pending_messages.insert(message_id, envelope);
247 self.message_timeouts.insert(message_id, Instant::now());
248 Ok(message_id)
249 } else {
250 Err(ClusteringError::InvalidInput(format!(
251 "Worker {} not registered",
252 workerid
253 )))
254 }
255 }
256
257 pub fn broadcast_message(
259 &mut self,
260 message: ClusteringMessage<F>,
261 priority: MessagePriority,
262 ) -> Result<Vec<u64>> {
263 let workerids: Vec<usize> = self.worker_channels.keys().copied().collect();
264 let mut message_ids = Vec::new();
265
266 for workerid in workerids {
267 let message_id = self.send_message_to_worker(workerid, message.clone(), priority)?;
268 message_ids.push(message_id);
269 }
270
271 Ok(message_ids)
272 }
273
274 pub fn process_messages(&mut self, timeout: Duration) -> Result<Vec<MessageEnvelope<F>>> {
276 let mut messages = Vec::new();
277 let deadline = Instant::now() + timeout;
278
279 while Instant::now() < deadline {
280 match self.coordinator_receiver.try_recv() {
281 Ok(envelope) => {
282 messages.push(envelope);
283 }
284 Err(std::sync::mpsc::TryRecvError::Empty) => {
285 break;
287 }
288 Err(std::sync::mpsc::TryRecvError::Disconnected) => {
289 return Err(ClusteringError::InvalidInput(
290 "Coordinator channel disconnected".to_string(),
291 ));
292 }
293 }
294 }
295
296 self.cleanup_timed_out_messages();
298
299 Ok(messages)
300 }
301
302 pub fn create_sync_barrier(
304 &mut self,
305 round: usize,
306 expected_participants: usize,
307 ) -> Result<()> {
308 let barrier = SynchronizationBarrier {
309 round,
310 expected_participants,
311 arrived_participants: HashSet::new(),
312 barrier_start_time: Instant::now(),
313 timeout_ms: self.config.sync_timeout_ms,
314 };
315
316 self.sync_barriers.insert(round, barrier);
317 Ok(())
318 }
319
320 pub fn wait_for_barrier(&mut self, round: usize) -> Result<bool> {
322 if let Some(barrier) = self.sync_barriers.get_mut(&round) {
323 let timeout_reached =
324 barrier.barrier_start_time.elapsed().as_millis() as u64 > barrier.timeout_ms;
325
326 if timeout_reached {
327 self.sync_barriers.remove(&round);
329 return Ok(false);
330 }
331
332 let all_arrived = barrier.arrived_participants.len() >= barrier.expected_participants;
333 if all_arrived {
334 self.sync_barriers.remove(&round);
335 Ok(true)
336 } else {
337 Ok(false)
338 }
339 } else {
340 Err(ClusteringError::InvalidInput(format!(
341 "Sync barrier for round {} not found",
342 round
343 )))
344 }
345 }
346
347 pub fn register_barrier_arrival(&mut self, round: usize, workerid: usize) -> Result<()> {
349 if let Some(barrier) = self.sync_barriers.get_mut(&round) {
350 barrier.arrived_participants.insert(workerid);
351 Ok(())
352 } else {
353 Err(ClusteringError::InvalidInput(format!(
354 "Sync barrier for round {} not found",
355 round
356 )))
357 }
358 }
359
360 fn cleanup_timed_out_messages(&mut self) {
362 let now = Instant::now();
363 let timeout_duration = Duration::from_millis(self.config.message_timeout_ms);
364
365 let mut timed_out_messages = Vec::new();
366
367 for (&message_id, &send_time) in &self.message_timeouts {
368 if now.duration_since(send_time) > timeout_duration {
369 timed_out_messages.push(message_id);
370 }
371 }
372
373 for message_id in timed_out_messages {
374 if let Some(envelope) = self.pending_messages.remove(&message_id) {
375 self.message_timeouts.remove(&message_id);
376
377 if envelope.retry_count < self.config.max_retry_attempts {
379 let mut retry_envelope = envelope;
380 retry_envelope.retry_count += 1;
381
382 if let Some(sender) = self.worker_channels.get(&retry_envelope.receiver_id) {
383 let _ = sender.send(retry_envelope);
384 }
385 } else {
386 self.worker_status
388 .insert(envelope.receiver_id, WorkerStatus::Failed);
389 }
390 }
391 }
392 }
393
394 pub fn get_worker_status(&self, workerid: usize) -> Option<WorkerStatus> {
396 self.worker_status.get(&workerid).copied()
397 }
398
399 pub fn update_worker_status(&mut self, workerid: usize, status: WorkerStatus) {
401 self.worker_status.insert(workerid, status);
402 }
403
404 pub fn get_active_workers(&self) -> Vec<usize> {
406 self.worker_status
407 .iter()
408 .filter(|(_, &status)| status == WorkerStatus::Active)
409 .map(|(&id, _)| id)
410 .collect()
411 }
412
413 pub fn get_failed_workers(&self) -> Vec<usize> {
415 self.worker_status
416 .iter()
417 .filter(|(_, &status)| status == WorkerStatus::Failed)
418 .map(|(&id, _)| id)
419 .collect()
420 }
421
422 pub fn shutdown(&mut self) {
424 let _ = self.broadcast_message(ClusteringMessage::Terminate, MessagePriority::Critical);
426
427 self.worker_channels.clear();
429 self.pending_messages.clear();
430 self.message_timeouts.clear();
431 self.worker_status.clear();
432 self.sync_barriers.clear();
433 }
434}
435
436impl SynchronizationBarrier {
437 pub fn is_complete(&self) -> bool {
439 self.arrived_participants.len() >= self.expected_participants
440 }
441
442 pub fn is_timed_out(&self) -> bool {
444 self.barrier_start_time.elapsed().as_millis() as u64 > self.timeout_ms
445 }
446
447 pub fn completion_percentage(&self) -> f64 {
449 self.arrived_participants.len() as f64 / self.expected_participants as f64
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// A freshly built coordinator carries its id and tracks nothing.
    #[test]
    fn test_message_passing_coordinator_creation() {
        let coordinator =
            MessagePassingCoordinator::<f64>::new(0, MessagePassingConfig::default());

        assert_eq!(coordinator.coordinator_id, 0);
        assert!(coordinator.worker_channels.is_empty());
        assert!(coordinator.pending_messages.is_empty());
    }

    /// Registering a worker creates its channel and marks it active.
    #[test]
    fn test_worker_registration() {
        let mut coordinator =
            MessagePassingCoordinator::<f64>::new(0, MessagePassingConfig::default());

        // Keep the receiver alive for the duration of the test.
        let _worker_rx = coordinator.register_worker(1);

        assert!(coordinator.worker_channels.contains_key(&1));
        assert_eq!(coordinator.get_worker_status(1), Some(WorkerStatus::Active));
    }

    /// Creating a barrier succeeds and stores it under its round.
    #[test]
    fn test_sync_barrier_creation() {
        let mut coordinator =
            MessagePassingCoordinator::<f64>::new(0, MessagePassingConfig::default());

        assert!(coordinator.create_sync_barrier(1, 3).is_ok());
        assert!(coordinator.sync_barriers.contains_key(&1));
    }

    /// Completion ratio and completeness follow the recorded arrivals.
    #[test]
    fn test_sync_barrier_completion() {
        let mut two_party_barrier = SynchronizationBarrier {
            round: 1,
            expected_participants: 2,
            arrived_participants: HashSet::new(),
            barrier_start_time: Instant::now(),
            timeout_ms: 1000,
        };

        // Nobody has arrived yet.
        assert!(!two_party_barrier.is_complete());
        assert_relative_eq!(two_party_barrier.completion_percentage(), 0.0);

        // First of two participants.
        two_party_barrier.arrived_participants.insert(1);
        assert_relative_eq!(two_party_barrier.completion_percentage(), 0.5);

        // Second participant completes the barrier.
        two_party_barrier.arrived_participants.insert(2);
        assert!(two_party_barrier.is_complete());
        assert_relative_eq!(two_party_barrier.completion_percentage(), 1.0);
    }
}
508}