1use crate::common::IntegrateFloat;
8use crate::error::{IntegrateError, IntegrateResult};
9use scirs2_core::ndarray::Array1;
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14
/// Unique identifier for a compute node in the distributed cluster.
///
/// Newtype over `u64` so node ids cannot be confused with `JobId`/`ChunkId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u64);

impl NodeId {
    /// Creates a node identifier from a raw `u64`.
    ///
    /// `const` so ids can be built in const contexts (e.g. well-known ids).
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric value of this identifier.
    pub const fn value(&self) -> u64 {
        self.0
    }
}

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Node({})", self.0)
    }
}
36
/// Unique identifier for a distributed integration job.
///
/// Newtype over `u64`, mirroring `NodeId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct JobId(pub u64);

impl JobId {
    /// Creates a job identifier from a raw `u64` (usable in const contexts).
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric value of this identifier.
    pub const fn value(&self) -> u64 {
        self.0
    }
}

/// Human-readable form, consistent with `NodeId`'s `Node(..)` style.
impl std::fmt::Display for JobId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Job({})", self.0)
    }
}
52
/// Unique identifier for a single work chunk within a job.
///
/// Newtype over `u64`, mirroring `NodeId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChunkId(pub u64);

impl ChunkId {
    /// Creates a chunk identifier from a raw `u64` (usable in const contexts).
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric value of this identifier.
    pub const fn value(&self) -> u64 {
        self.0
    }
}

/// Human-readable form, consistent with `NodeId`'s `Node(..)` style;
/// lets diagnostics avoid falling back to `{:?}`.
impl std::fmt::Display for ChunkId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Chunk({})", self.0)
    }
}
68
/// Lifecycle/availability state of a compute node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeStatus {
    /// Node is ready to accept work.
    Available,
    /// Node is occupied with work.
    Busy,
    /// Node has failed; `NodeInfo::is_healthy` reports it unhealthy.
    Failed,
    /// Node is temporarily out of rotation for maintenance.
    Maintenance,
    /// Node is starting up (initial state assigned by `NodeInfo::new`).
    Initializing,
    /// Node is shutting down; `NodeInfo::is_healthy` reports it unhealthy.
    ShuttingDown,
}
85
/// Static description of a node's hardware resources, consumed by
/// scheduling heuristics (see `NodeInfo::processing_score`).
#[derive(Debug, Clone)]
pub struct NodeCapabilities {
    /// Number of CPU cores available for work.
    pub cpu_cores: usize,
    /// Total memory in bytes.
    pub memory_bytes: usize,
    /// Whether a GPU is present (adds a flat scheduling bonus).
    pub has_gpu: bool,
    /// GPU memory in bytes, if a GPU is present.
    pub gpu_memory_bytes: Option<usize>,
    /// Network bandwidth; units not fixed here — the default suggests
    /// bytes/s. NOTE(review): confirm against whatever populates this.
    pub network_bandwidth: usize,
    /// Network latency in microseconds (per the `_us` suffix).
    pub latency_us: u64,
    /// Floating-point precisions this node can compute with.
    pub supported_precisions: Vec<FloatPrecision>,
    /// SIMD instruction-set features detected on this node.
    pub simd_capabilities: SimdCapability,
}
106
107impl Default for NodeCapabilities {
108 fn default() -> Self {
109 Self {
110 cpu_cores: 1,
111 memory_bytes: 1024 * 1024 * 1024, has_gpu: false,
113 gpu_memory_bytes: None,
114 network_bandwidth: 100 * 1024 * 1024, latency_us: 1000, supported_precisions: vec![FloatPrecision::F32, FloatPrecision::F64],
117 simd_capabilities: SimdCapability::default(),
118 }
119 }
120}
121
/// Floating-point precisions a node may advertise support for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FloatPrecision {
    /// 16-bit half precision.
    F16,
    /// 32-bit single precision.
    F32,
    /// 64-bit double precision.
    F64,
}
132
/// SIMD instruction-set extensions available on a node.
///
/// The derived `Default` is all-false (no SIMD assumed).
#[derive(Debug, Clone, Default)]
pub struct SimdCapability {
    /// x86 SSE.
    pub has_sse: bool,
    /// x86 SSE2.
    pub has_sse2: bool,
    /// x86 AVX.
    pub has_avx: bool,
    /// x86 AVX2.
    pub has_avx2: bool,
    /// x86 AVX-512.
    pub has_avx512: bool,
    /// ARM NEON.
    pub has_neon: bool,
}
149
/// Runtime record for a cluster node: identity, liveness and bookkeeping.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Unique node identifier.
    pub id: NodeId,
    /// Network address the node is reachable at.
    pub address: SocketAddr,
    /// Current lifecycle status.
    pub status: NodeStatus,
    /// Hardware profile used for scheduling decisions.
    pub capabilities: NodeCapabilities,
    /// Time of the most recent heartbeat; drives `is_healthy`.
    pub last_heartbeat: Instant,
    /// Number of jobs this node has completed.
    pub jobs_completed: usize,
    /// Running average of job durations on this node.
    pub average_job_duration: Duration,
}
168
169impl NodeInfo {
170 pub fn new(id: NodeId, address: SocketAddr) -> Self {
172 Self {
173 id,
174 address,
175 status: NodeStatus::Initializing,
176 capabilities: NodeCapabilities::default(),
177 last_heartbeat: Instant::now(),
178 jobs_completed: 0,
179 average_job_duration: Duration::ZERO,
180 }
181 }
182
183 pub fn is_healthy(&self, timeout: Duration) -> bool {
185 self.last_heartbeat.elapsed() < timeout
186 && self.status != NodeStatus::Failed
187 && self.status != NodeStatus::ShuttingDown
188 }
189
190 pub fn processing_score(&self) -> f64 {
192 let base_score = self.capabilities.cpu_cores as f64;
193 let gpu_bonus = if self.capabilities.has_gpu { 10.0 } else { 0.0 };
194 let latency_penalty = (self.capabilities.latency_us as f64 / 1000.0).min(5.0);
195
196 base_score + gpu_bonus - latency_penalty
197 }
198}
199
/// A unit of integration work: one time sub-interval of a job, dispatched
/// to a single node.
#[derive(Debug, Clone)]
pub struct WorkChunk<F: IntegrateFloat> {
    /// Unique chunk identifier.
    pub id: ChunkId,
    /// The job this chunk belongs to.
    pub job_id: JobId,
    /// Integration interval as (start, end) times.
    pub time_interval: (F, F),
    /// State vector at the start of the interval.
    pub initial_state: Array1<F>,
    /// Boundary/coupling data exchanged with neighbouring chunks.
    pub boundary_conditions: BoundaryConditions<F>,
    /// Scheduling priority (presumably higher wins — set by the scheduler;
    /// not interpreted in this module).
    pub priority: u32,
    /// Estimated processing cost: interval width times state size.
    pub estimated_cost: f64,
    /// Failed attempts so far.
    pub retry_count: u32,
    /// Maximum attempts before the chunk is abandoned (default 3).
    pub max_retries: u32,
}
222
223impl<F: IntegrateFloat> WorkChunk<F> {
224 pub fn new(
226 id: ChunkId,
227 job_id: JobId,
228 time_interval: (F, F),
229 initial_state: Array1<F>,
230 ) -> Self {
231 let estimated_cost = Self::estimate_cost(&time_interval, initial_state.len());
232 Self {
233 id,
234 job_id,
235 time_interval,
236 initial_state,
237 boundary_conditions: BoundaryConditions::default(),
238 priority: 0,
239 estimated_cost,
240 retry_count: 0,
241 max_retries: 3,
242 }
243 }
244
245 fn estimate_cost(time_interval: &(F, F), state_size: usize) -> f64 {
247 let dt = (time_interval.1 - time_interval.0).to_f64().unwrap_or(1.0);
248 dt * state_size as f64
249 }
250
251 pub fn can_retry(&self) -> bool {
253 self.retry_count < self.max_retries
254 }
255
256 pub fn increment_retry(&mut self) {
258 self.retry_count += 1;
259 }
260}
261
/// Boundary information a chunk needs from (or shares with) its neighbours.
#[derive(Debug, Clone)]
pub struct BoundaryConditions<F: IntegrateFloat> {
    /// Data at the left boundary (presumably the earlier-time side), if any.
    pub left_boundary: Option<BoundaryData<F>>,
    /// Data at the right boundary (presumably the later-time side), if any.
    pub right_boundary: Option<BoundaryData<F>>,
    /// Ghost-cell values; empty when unused.
    pub ghost_cells: Vec<F>,
    /// Named coupling arrays exchanged with other chunks.
    pub coupling_data: HashMap<String, Array1<F>>,
}
274
275impl<F: IntegrateFloat> Default for BoundaryConditions<F> {
276 fn default() -> Self {
277 Self {
278 left_boundary: None,
279 right_boundary: None,
280 ghost_cells: Vec::new(),
281 coupling_data: HashMap::new(),
282 }
283 }
284}
285
/// A boundary snapshot produced by one chunk for consumption by another.
#[derive(Debug, Clone)]
pub struct BoundaryData<F: IntegrateFloat> {
    /// Time at which the boundary state applies.
    pub time: F,
    /// State vector at the boundary.
    pub state: Array1<F>,
    /// Optional derivative at the boundary.
    pub derivative: Option<Array1<F>>,
    /// The chunk that produced this data.
    pub source_chunk: ChunkId,
}
298
/// Outcome of processing a single `WorkChunk` on a node.
#[derive(Debug, Clone)]
pub struct ChunkResult<F: IntegrateFloat> {
    /// The chunk this result belongs to.
    pub chunk_id: ChunkId,
    /// The node that produced the result.
    pub node_id: NodeId,
    /// Time points of the computed trajectory.
    pub time_points: Vec<F>,
    /// State vectors corresponding one-to-one with `time_points`.
    pub states: Vec<Array1<F>>,
    /// State at the end of the chunk's interval.
    pub final_state: Array1<F>,
    /// Derivative at the end of the interval, if computed.
    pub final_derivative: Option<Array1<F>>,
    /// Local error estimate for this chunk.
    pub error_estimate: F,
    /// Wall-clock time spent processing the chunk.
    pub processing_time: Duration,
    /// Memory used in bytes (peak vs. total is set by the producer —
    /// not defined in this module).
    pub memory_used: usize,
    /// Success/failure classification of the result.
    pub status: ChunkResultStatus,
}
323
/// Classification of a chunk's processing outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkResultStatus {
    /// Chunk completed successfully.
    Success,
    /// Chunk failed; may be retried while `WorkChunk::can_retry` is true.
    Failed,
    /// Result was produced but the chunk should be re-run at finer
    /// resolution.
    NeedsRefinement,
    /// Chunk was cancelled before completion.
    Cancelled,
}
336
/// Tuning knobs for a distributed integration run.
#[derive(Debug, Clone)]
pub struct DistributedConfig<F: IntegrateFloat> {
    /// Smallest allowed chunk width (time units).
    pub min_chunk_size: F,
    /// Largest allowed chunk width (time units).
    pub max_chunk_size: F,
    /// Target number of chunks assigned per node.
    pub chunks_per_node: usize,
    /// Integration error tolerance.
    pub tolerance: F,
    /// Upper bound on iterations.
    pub max_iterations: usize,
    /// Whether periodic checkpoints are taken.
    pub checkpointing_enabled: bool,
    /// Iterations between checkpoints (when enabled).
    pub checkpoint_interval: usize,
    /// Timeout applied to cluster communication.
    pub communication_timeout: Duration,
    /// Interval between node heartbeats.
    pub heartbeat_interval: Duration,
    /// Maximum retries for a failed chunk.
    pub max_retries: u32,
    /// Strategy for distributing chunks across nodes.
    pub load_balancing: LoadBalancingStrategy,
    /// Degree of fault tolerance to apply.
    pub fault_tolerance: FaultToleranceMode,
}
365
366impl<F: IntegrateFloat> Default for DistributedConfig<F> {
367 fn default() -> Self {
368 Self {
369 min_chunk_size: F::from(0.001).unwrap_or(F::epsilon()),
370 max_chunk_size: F::from(1.0).unwrap_or(F::one()),
371 chunks_per_node: 4,
372 tolerance: F::from(1e-6).unwrap_or(F::epsilon()),
373 max_iterations: 1000,
374 checkpointing_enabled: true,
375 checkpoint_interval: 10,
376 communication_timeout: Duration::from_secs(30),
377 heartbeat_interval: Duration::from_secs(5),
378 max_retries: 3,
379 load_balancing: LoadBalancingStrategy::Adaptive,
380 fault_tolerance: FaultToleranceMode::Standard,
381 }
382 }
383}
384
/// Strategies for assigning work chunks to nodes.
///
/// NOTE(review): the actual assignment behavior is implemented by the
/// scheduler, not in this module; descriptions below follow the variant
/// names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Cycle through nodes in order.
    RoundRobin,
    /// Weight assignment by node capabilities.
    CapabilityBased,
    /// Let idle nodes steal queued work from busy ones.
    WorkStealing,
    /// Adaptive strategy (the `DistributedConfig` default).
    Adaptive,
    /// Prefer placements that keep related data together.
    LocalityAware,
}
399
/// Levels of fault tolerance for a distributed run.
///
/// NOTE(review): enforcement happens in the runtime, not in this module;
/// descriptions below follow the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FaultToleranceMode {
    /// No recovery on failure.
    None,
    /// Standard handling (the `DistributedConfig` default).
    Standard,
    /// High-availability operation.
    HighAvailability,
    /// Recovery of failed work from checkpoints.
    CheckpointRecovery,
}
412
/// Messages exchanged between nodes and the coordinator.
#[derive(Debug, Clone)]
pub enum DistributedMessage<F: IntegrateFloat> {
    /// Periodic liveness signal from a node.
    Heartbeat {
        node_id: NodeId,
        status: NodeStatus,
        // Sender-side timestamp; encoding (epoch, resolution) is decided
        // by the transport layer — not defined here.
        timestamp: u64,
    },
    /// Assignment of a chunk to a node, optionally with a deadline.
    WorkAssignment {
        chunk: WorkChunk<F>,
        deadline: Option<Duration>,
    },
    /// A node reporting a finished (or failed) chunk.
    WorkResult { result: ChunkResult<F> },
    /// Boundary data passed from one chunk to a neighbouring chunk.
    BoundaryExchange {
        source_chunk: ChunkId,
        target_chunk: ChunkId,
        boundary_data: BoundaryData<F>,
    },
    /// Request that a checkpoint be taken for a job.
    CheckpointRequest { job_id: JobId, checkpoint_id: u64 },
    /// Serialized checkpoint payload produced by one node.
    CheckpointData {
        job_id: JobId,
        checkpoint_id: u64,
        node_id: NodeId,
        data: Vec<u8>,
    },
    /// A node announcing itself and its capabilities to the cluster.
    NodeRegister {
        node_id: NodeId,
        address: SocketAddr,
        capabilities: NodeCapabilities,
    },
    /// A node leaving the cluster, with a human-readable reason.
    NodeDeregister { node_id: NodeId, reason: String },
    /// Cancellation of all work belonging to a job.
    JobCancel { job_id: JobId, reason: String },
    /// Participation notice for a synchronization barrier.
    SyncBarrier { barrier_id: u64, node_id: NodeId },
    /// Acknowledgement of a previously received message.
    Ack { message_id: u64, status: AckStatus },
}
459
/// Outcome carried by a `DistributedMessage::Ack`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckStatus {
    /// Message was processed successfully.
    Ok,
    /// Message processing failed.
    Error,
    /// Status could not be determined.
    Unknown,
}
470
/// Aggregated runtime statistics for a distributed integration run.
///
/// `Default` starts all counters at zero and durations at `Duration::ZERO`;
/// `load_balance_efficiency` starts at 0.0 until first updated.
#[derive(Debug, Clone, Default)]
pub struct DistributedMetrics {
    /// Chunks completed.
    pub chunks_processed: usize,
    /// Chunks that failed.
    pub chunks_failed: usize,
    /// Chunks that were retried.
    pub chunks_retried: usize,
    /// Total time spent processing chunks.
    pub total_processing_time: Duration,
    /// Total time spent on communication.
    pub total_communication_time: Duration,
    /// Average processing time per chunk.
    pub average_chunk_time: Duration,
    /// Load-balance quality in [0, 1]; see `update_load_balance`.
    pub load_balance_efficiency: f64,
    /// Bytes sent over the network.
    pub bytes_sent: usize,
    /// Bytes received over the network.
    pub bytes_received: usize,
    /// Checkpoints created.
    pub checkpoints_created: usize,
    /// Recovery operations performed.
    pub recoveries: usize,
}

impl DistributedMetrics {
    /// Recomputes `load_balance_efficiency` from per-node load samples.
    ///
    /// Efficiency is `clamp(1 - min(CV, 1), 0, 1)` where CV is the
    /// coefficient of variation (stddev / mean) of the loads. An empty
    /// sample set or a non-positive mean is treated as perfectly balanced
    /// (efficiency 1.0).
    pub fn update_load_balance(&mut self, node_loads: &[f64]) {
        let count = node_loads.len();
        if count == 0 {
            self.load_balance_efficiency = 1.0;
            return;
        }

        let mean = node_loads.iter().sum::<f64>() / count as f64;
        if mean <= 0.0 {
            // Degenerate loads: nothing meaningful to balance.
            self.load_balance_efficiency = 1.0;
            return;
        }

        // Population variance of the load samples.
        let variance = node_loads
            .iter()
            .map(|&load| {
                let delta = load - mean;
                delta * delta
            })
            .sum::<f64>()
            / count as f64;

        let cv = variance.sqrt() / mean;
        self.load_balance_efficiency = (1.0 - cv.min(1.0)).max(0.0);
    }
}
522
/// Errors specific to distributed execution.
///
/// Convertible into `IntegrateError` (flattened to `ComputationError`).
#[derive(Debug, Clone)]
pub enum DistributedError {
    /// Message transport failed.
    CommunicationError(String),
    /// A node missed its communication deadline.
    NodeTimeout(NodeId),
    /// A node failed, with a reason string.
    NodeFailure(NodeId, String),
    /// Processing a specific chunk failed.
    ChunkError(ChunkId, String),
    /// A synchronization operation failed.
    SyncError(String),
    /// Checkpoint creation or recovery failed.
    CheckpointError(String),
    /// Invalid or inconsistent configuration.
    ConfigError(String),
    /// A resource limit was exceeded.
    ResourceExhausted(String),
}
543
544impl std::fmt::Display for DistributedError {
545 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 match self {
547 Self::CommunicationError(msg) => write!(f, "Communication error: {}", msg),
548 Self::NodeTimeout(id) => write!(f, "Node {} timed out", id),
549 Self::NodeFailure(id, msg) => write!(f, "Node {} failed: {}", id, msg),
550 Self::ChunkError(id, msg) => write!(f, "Chunk {:?} error: {}", id, msg),
551 Self::SyncError(msg) => write!(f, "Synchronization error: {}", msg),
552 Self::CheckpointError(msg) => write!(f, "Checkpoint error: {}", msg),
553 Self::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
554 Self::ResourceExhausted(msg) => write!(f, "Resource exhausted: {}", msg),
555 }
556 }
557}
558
// Marker impl: no underlying `source()`; the message comes from `Display`.
impl std::error::Error for DistributedError {}
560
561impl From<DistributedError> for IntegrateError {
562 fn from(err: DistributedError) -> Self {
563 IntegrateError::ComputationError(err.to_string())
564 }
565}
566
/// Convenience alias for operations that may fail with a `DistributedError`.
pub type DistributedResult<T> = std::result::Result<T, DistributedError>;
569
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    #[test]
    fn test_node_id_display() {
        // Display wraps the raw id in "Node(..)".
        assert_eq!(NodeId::new(42).to_string(), "Node(42)");
    }

    #[test]
    fn test_node_info_health_check() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let mut node = NodeInfo::new(NodeId::new(1), addr);
        node.status = NodeStatus::Available;

        // Fresh heartbeat inside the timeout window => healthy.
        assert!(node.is_healthy(Duration::from_secs(60)));

        // Back-date the heartbeat beyond the timeout => unhealthy.
        node.last_heartbeat = Instant::now() - Duration::from_secs(120);
        assert!(!node.is_healthy(Duration::from_secs(60)));
    }

    #[test]
    fn test_work_chunk_retry() {
        let mut chunk: WorkChunk<f64> =
            WorkChunk::new(ChunkId::new(1), JobId::new(1), (0.0, 1.0), Array1::zeros(3));

        // Fresh chunk has its full retry budget.
        assert!(chunk.can_retry());

        // Exhaust the budget (max_retries defaults to 3).
        for _ in 0..chunk.max_retries {
            chunk.increment_retry();
        }
        assert!(!chunk.can_retry());
    }

    #[test]
    fn test_distributed_metrics_load_balance() {
        let mut metrics = DistributedMetrics::default();

        // Uniform loads => near-perfect efficiency.
        metrics.update_load_balance(&[1.0, 1.0, 1.0, 1.0]);
        assert!((metrics.load_balance_efficiency - 1.0).abs() < 0.01);

        // Highly skewed loads => poor efficiency.
        metrics.update_load_balance(&[0.1, 0.1, 0.1, 3.7]);
        assert!(metrics.load_balance_efficiency < 0.5);
    }
}
619}