// scirs2_integrate/distributed/types.rs
//! Core types for distributed computing in numerical integration
//!
//! This module provides the fundamental types and abstractions for
//! distributed computation of ODEs and numerical integration across
//! multiple compute nodes.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};

use scirs2_core::ndarray::Array1;

use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};

/// Unique identifier for a compute node in the distributed system
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u64);

impl NodeId {
    /// Create a new node ID wrapping the raw `u64`
    pub fn new(id: u64) -> Self {
        NodeId(id)
    }

    /// Get the raw ID value
    pub fn value(&self) -> u64 {
        let NodeId(raw) = self;
        *raw
    }
}

impl std::fmt::Display for NodeId {
    /// Render as `Node(<id>)` for log and error messages
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Node({})", self.value())
    }
}
36
/// Unique identifier for a distributed computation job
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct JobId(pub u64);

impl JobId {
    /// Create a new job ID
    pub fn new(id: u64) -> Self {
        Self(id)
    }

    /// Get the raw ID value
    pub fn value(&self) -> u64 {
        self.0
    }
}

impl std::fmt::Display for JobId {
    /// Render as `Job(<id>)`, mirroring the `Node(<id>)` convention used
    /// by node identifiers, so job IDs format cleanly in logs and errors.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Job({})", self.0)
    }
}
52
/// Unique identifier for a work chunk
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChunkId(pub u64);

impl ChunkId {
    /// Create a new chunk ID
    pub fn new(id: u64) -> Self {
        Self(id)
    }

    /// Get the raw ID value
    pub fn value(&self) -> u64 {
        self.0
    }
}

impl std::fmt::Display for ChunkId {
    /// Render as `Chunk(<id>)`, mirroring the `Node(<id>)` convention;
    /// callers formatting chunk IDs no longer need to fall back to `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Chunk({})", self.0)
    }
}
68
/// Status of a compute node at any point in its lifecycle.
///
/// Used by `NodeInfo::is_healthy` to decide whether a node can accept work:
/// `Failed` and `ShuttingDown` nodes are always considered unhealthy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeStatus {
    /// Node is available for work
    Available,
    /// Node is currently processing work
    Busy,
    /// Node has failed and is unavailable
    Failed,
    /// Node is in maintenance mode
    Maintenance,
    /// Node is starting up (initial status assigned by `NodeInfo::new`)
    Initializing,
    /// Node is shutting down
    ShuttingDown,
}
85
/// Capabilities and resources of a compute node.
///
/// Consumed by `NodeInfo::processing_score` for load balancing; see the
/// `Default` impl for the conservative baseline values.
#[derive(Debug, Clone)]
pub struct NodeCapabilities {
    /// Number of CPU cores available
    pub cpu_cores: usize,
    /// Available memory in bytes
    pub memory_bytes: usize,
    /// Whether GPU acceleration is available
    pub has_gpu: bool,
    /// GPU memory in bytes (`None` when `has_gpu` is false)
    pub gpu_memory_bytes: Option<usize>,
    /// Network bandwidth in bytes per second
    pub network_bandwidth: usize,
    /// Latency to coordinator in microseconds
    pub latency_us: u64,
    /// Supported floating-point precisions
    pub supported_precisions: Vec<FloatPrecision>,
    /// SIMD capabilities
    pub simd_capabilities: SimdCapability,
}
106
107impl Default for NodeCapabilities {
108    fn default() -> Self {
109        Self {
110            cpu_cores: 1,
111            memory_bytes: 1024 * 1024 * 1024, // 1 GB
112            has_gpu: false,
113            gpu_memory_bytes: None,
114            network_bandwidth: 100 * 1024 * 1024, // 100 MB/s
115            latency_us: 1000,                     // 1ms
116            supported_precisions: vec![FloatPrecision::F32, FloatPrecision::F64],
117            simd_capabilities: SimdCapability::default(),
118        }
119    }
120}
121
/// Floating-point precision options a node can advertise in
/// `NodeCapabilities::supported_precisions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FloatPrecision {
    /// 16-bit floating point (half)
    F16,
    /// 32-bit floating point (single)
    F32,
    /// 64-bit floating point (double)
    F64,
}
132
/// SIMD capability information for a node.
///
/// All flags default to `false` via `#[derive(Default)]`; a node reports
/// only the instruction-set extensions it actually supports.
#[derive(Debug, Clone, Default)]
pub struct SimdCapability {
    /// SSE support (x86)
    pub has_sse: bool,
    /// SSE2 support (x86)
    pub has_sse2: bool,
    /// AVX support (x86)
    pub has_avx: bool,
    /// AVX2 support (x86)
    pub has_avx2: bool,
    /// AVX-512 support (x86)
    pub has_avx512: bool,
    /// NEON support (ARM)
    pub has_neon: bool,
}
149
/// Information about a compute node, tracked by the coordinator for
/// health checks (`is_healthy`) and load balancing (`processing_score`).
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Unique identifier for this node
    pub id: NodeId,
    /// Network address of the node
    pub address: SocketAddr,
    /// Current status of the node
    pub status: NodeStatus,
    /// Node capabilities
    pub capabilities: NodeCapabilities,
    /// Last heartbeat timestamp (compared against a timeout in `is_healthy`)
    pub last_heartbeat: Instant,
    /// Number of jobs completed
    pub jobs_completed: usize,
    /// Average job duration
    pub average_job_duration: Duration,
}
168
169impl NodeInfo {
170    /// Create a new node info with default capabilities
171    pub fn new(id: NodeId, address: SocketAddr) -> Self {
172        Self {
173            id,
174            address,
175            status: NodeStatus::Initializing,
176            capabilities: NodeCapabilities::default(),
177            last_heartbeat: Instant::now(),
178            jobs_completed: 0,
179            average_job_duration: Duration::ZERO,
180        }
181    }
182
183    /// Check if the node is healthy (recent heartbeat)
184    pub fn is_healthy(&self, timeout: Duration) -> bool {
185        self.last_heartbeat.elapsed() < timeout
186            && self.status != NodeStatus::Failed
187            && self.status != NodeStatus::ShuttingDown
188    }
189
190    /// Calculate the node's processing score for load balancing
191    pub fn processing_score(&self) -> f64 {
192        let base_score = self.capabilities.cpu_cores as f64;
193        let gpu_bonus = if self.capabilities.has_gpu { 10.0 } else { 0.0 };
194        let latency_penalty = (self.capabilities.latency_us as f64 / 1000.0).min(5.0);
195
196        base_score + gpu_bonus - latency_penalty
197    }
198}
199
/// A chunk of work to be processed by a node: one time sub-interval of a
/// larger integration job, together with the state needed to start it.
#[derive(Debug, Clone)]
pub struct WorkChunk<F: IntegrateFloat> {
    /// Unique identifier for this chunk
    pub id: ChunkId,
    /// Job this chunk belongs to
    pub job_id: JobId,
    /// Time interval for this chunk [t_start, t_end]
    pub time_interval: (F, F),
    /// Initial state for this chunk
    pub initial_state: Array1<F>,
    /// Boundary conditions from adjacent chunks
    pub boundary_conditions: BoundaryConditions<F>,
    /// Priority level (higher = more urgent)
    pub priority: u32,
    /// Estimated computational cost (interval length × state size; see
    /// `WorkChunk::estimate_cost`)
    pub estimated_cost: f64,
    /// Number of retry attempts made so far
    pub retry_count: u32,
    /// Maximum allowed retries (compared in `can_retry`)
    pub max_retries: u32,
}
222
223impl<F: IntegrateFloat> WorkChunk<F> {
224    /// Create a new work chunk
225    pub fn new(
226        id: ChunkId,
227        job_id: JobId,
228        time_interval: (F, F),
229        initial_state: Array1<F>,
230    ) -> Self {
231        let estimated_cost = Self::estimate_cost(&time_interval, initial_state.len());
232        Self {
233            id,
234            job_id,
235            time_interval,
236            initial_state,
237            boundary_conditions: BoundaryConditions::default(),
238            priority: 0,
239            estimated_cost,
240            retry_count: 0,
241            max_retries: 3,
242        }
243    }
244
245    /// Estimate computational cost based on interval and state size
246    fn estimate_cost(time_interval: &(F, F), state_size: usize) -> f64 {
247        let dt = (time_interval.1 - time_interval.0).to_f64().unwrap_or(1.0);
248        dt * state_size as f64
249    }
250
251    /// Check if this chunk can be retried
252    pub fn can_retry(&self) -> bool {
253        self.retry_count < self.max_retries
254    }
255
256    /// Increment retry count
257    pub fn increment_retry(&mut self) {
258        self.retry_count += 1;
259    }
260}
261
/// Boundary conditions for inter-chunk communication.
///
/// Carries the data a chunk needs from its temporal neighbors; all fields
/// are empty/`None` by default (see the `Default` impl).
#[derive(Debug, Clone)]
pub struct BoundaryConditions<F: IntegrateFloat> {
    /// Left boundary values (from previous chunk)
    pub left_boundary: Option<BoundaryData<F>>,
    /// Right boundary values (from next chunk)
    pub right_boundary: Option<BoundaryData<F>>,
    /// Ghost cells for finite difference methods
    pub ghost_cells: Vec<F>,
    /// Coupling information for multi-physics, keyed by field name
    pub coupling_data: HashMap<String, Array1<F>>,
}
274
275impl<F: IntegrateFloat> Default for BoundaryConditions<F> {
276    fn default() -> Self {
277        Self {
278            left_boundary: None,
279            right_boundary: None,
280            ghost_cells: Vec::new(),
281            coupling_data: HashMap::new(),
282        }
283    }
284}
285
/// Data at a boundary between chunks, exchanged via
/// `DistributedMessage::BoundaryExchange`.
#[derive(Debug, Clone)]
pub struct BoundaryData<F: IntegrateFloat> {
    /// Time at which boundary data is valid
    pub time: F,
    /// State values at boundary
    pub state: Array1<F>,
    /// Derivative values at boundary (for higher-order continuity)
    pub derivative: Option<Array1<F>>,
    /// Chunk that produced this boundary data
    pub source_chunk: ChunkId,
}
298
/// Result of processing a work chunk, returned to the coordinator via
/// `DistributedMessage::WorkResult`.
#[derive(Debug, Clone)]
pub struct ChunkResult<F: IntegrateFloat> {
    /// Chunk that was processed
    pub chunk_id: ChunkId,
    /// Node that processed the chunk
    pub node_id: NodeId,
    /// Time points in the solution
    pub time_points: Vec<F>,
    /// Solution states at each time point (parallel to `time_points`)
    pub states: Vec<Array1<F>>,
    /// Final state for continuity with next chunk
    pub final_state: Array1<F>,
    /// Final derivative for higher-order continuity
    pub final_derivative: Option<Array1<F>>,
    /// Error estimate for this chunk
    pub error_estimate: F,
    /// Processing duration
    pub processing_time: Duration,
    /// Memory usage in bytes
    pub memory_used: usize,
    /// Status of the result (success, failure, needs refinement, …)
    pub status: ChunkResultStatus,
}
323
/// Status of a chunk result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkResultStatus {
    /// Chunk was processed successfully
    Success,
    /// Chunk processing failed
    Failed,
    /// Chunk needs to be reprocessed (e.g., tolerance not met)
    NeedsRefinement,
    /// Chunk was cancelled
    Cancelled,
}
336
/// Configuration for distributed computation.
///
/// See the `Default` impl for baseline values.
#[derive(Debug, Clone)]
pub struct DistributedConfig<F: IntegrateFloat> {
    /// Minimum chunk size (length of a chunk's time interval)
    pub min_chunk_size: F,
    /// Maximum chunk size (length of a chunk's time interval)
    pub max_chunk_size: F,
    /// Target number of chunks per node
    pub chunks_per_node: usize,
    /// Tolerance for solution accuracy
    pub tolerance: F,
    /// Maximum iterations for convergence
    pub max_iterations: usize,
    /// Enable checkpointing
    pub checkpointing_enabled: bool,
    /// Checkpoint interval (number of chunks between checkpoints)
    pub checkpoint_interval: usize,
    /// Communication timeout
    pub communication_timeout: Duration,
    /// Heartbeat interval
    pub heartbeat_interval: Duration,
    /// Maximum retries for failed chunks
    pub max_retries: u32,
    /// Load balancing strategy
    pub load_balancing: LoadBalancingStrategy,
    /// Fault tolerance mode
    pub fault_tolerance: FaultToleranceMode,
}
365
366impl<F: IntegrateFloat> Default for DistributedConfig<F> {
367    fn default() -> Self {
368        Self {
369            min_chunk_size: F::from(0.001).unwrap_or(F::epsilon()),
370            max_chunk_size: F::from(1.0).unwrap_or(F::one()),
371            chunks_per_node: 4,
372            tolerance: F::from(1e-6).unwrap_or(F::epsilon()),
373            max_iterations: 1000,
374            checkpointing_enabled: true,
375            checkpoint_interval: 10,
376            communication_timeout: Duration::from_secs(30),
377            heartbeat_interval: Duration::from_secs(5),
378            max_retries: 3,
379            load_balancing: LoadBalancingStrategy::Adaptive,
380            fault_tolerance: FaultToleranceMode::Standard,
381        }
382    }
383}
384
/// Load balancing strategies for distributing work across nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Round-robin distribution
    RoundRobin,
    /// Distribute based on node capabilities
    CapabilityBased,
    /// Dynamic work stealing
    WorkStealing,
    /// Adaptive strategy that adjusts based on performance (the default)
    Adaptive,
    /// Minimize communication by keeping related chunks together
    LocalityAware,
}
399
/// Fault tolerance modes, ordered roughly from fastest/least-safe to
/// most resilient.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FaultToleranceMode {
    /// No fault tolerance (fastest but risky)
    None,
    /// Standard fault tolerance with retries (the default)
    Standard,
    /// High availability with replication
    HighAvailability,
    /// Checkpoint-based recovery
    CheckpointRecovery,
}
412
/// Message types for inter-node communication.
///
/// Covers the full protocol: liveness (heartbeats), work distribution,
/// boundary exchange, checkpointing, membership, and acknowledgments.
#[derive(Debug, Clone)]
pub enum DistributedMessage<F: IntegrateFloat> {
    /// Heartbeat message, sent periodically by each node
    Heartbeat {
        node_id: NodeId,
        status: NodeStatus,
        timestamp: u64,
    },
    /// Work assignment message (`deadline` is optional)
    WorkAssignment {
        chunk: WorkChunk<F>,
        deadline: Option<Duration>,
    },
    /// Work result message, returned after processing a chunk
    WorkResult { result: ChunkResult<F> },
    /// Boundary data exchange between adjacent chunks
    BoundaryExchange {
        source_chunk: ChunkId,
        target_chunk: ChunkId,
        boundary_data: BoundaryData<F>,
    },
    /// Checkpoint request issued by the coordinator
    CheckpointRequest { job_id: JobId, checkpoint_id: u64 },
    /// Checkpoint data (serialized payload in `data`)
    CheckpointData {
        job_id: JobId,
        checkpoint_id: u64,
        node_id: NodeId,
        data: Vec<u8>,
    },
    /// Node registration, announcing address and capabilities
    NodeRegister {
        node_id: NodeId,
        address: SocketAddr,
        capabilities: NodeCapabilities,
    },
    /// Node deregistration with a human-readable reason
    NodeDeregister { node_id: NodeId, reason: String },
    /// Job cancellation with a human-readable reason
    JobCancel { job_id: JobId, reason: String },
    /// Synchronization barrier participation notice
    SyncBarrier { barrier_id: u64, node_id: NodeId },
    /// Acknowledgment of a previously received message
    Ack { message_id: u64, status: AckStatus },
}
459
/// Acknowledgment status carried in `DistributedMessage::Ack`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckStatus {
    /// Message received and processed
    Ok,
    /// Message received but processing failed
    Error,
    /// Message not understood
    Unknown,
}
470
/// Metrics for monitoring distributed computation.
///
/// All counters start at zero via `#[derive(Default)]`.
#[derive(Debug, Clone, Default)]
pub struct DistributedMetrics {
    /// Total number of chunks processed
    pub chunks_processed: usize,
    /// Number of chunks failed
    pub chunks_failed: usize,
    /// Number of chunks retried
    pub chunks_retried: usize,
    /// Total processing time
    pub total_processing_time: Duration,
    /// Total communication time
    pub total_communication_time: Duration,
    /// Average chunk processing time
    pub average_chunk_time: Duration,
    /// Load balance efficiency in [0.0, 1.0]; maintained by
    /// `update_load_balance` (1.0 = perfectly even node loads)
    pub load_balance_efficiency: f64,
    /// Network bytes sent
    pub bytes_sent: usize,
    /// Network bytes received
    pub bytes_received: usize,
    /// Number of checkpoints created
    pub checkpoints_created: usize,
    /// Number of recoveries from failures
    pub recoveries: usize,
}
497
498impl DistributedMetrics {
499    /// Update load balance efficiency
500    pub fn update_load_balance(&mut self, node_loads: &[f64]) {
501        if node_loads.is_empty() {
502            self.load_balance_efficiency = 1.0;
503            return;
504        }
505
506        let mean_load: f64 = node_loads.iter().sum::<f64>() / node_loads.len() as f64;
507        if mean_load <= 0.0 {
508            self.load_balance_efficiency = 1.0;
509            return;
510        }
511
512        let variance: f64 = node_loads
513            .iter()
514            .map(|&load| (load - mean_load).powi(2))
515            .sum::<f64>()
516            / node_loads.len() as f64;
517
518        let cv = variance.sqrt() / mean_load; // Coefficient of variation
519        self.load_balance_efficiency = (1.0 - cv.min(1.0)).max(0.0);
520    }
521}
522
/// Error types specific to distributed computing.
///
/// Convertible into the crate-wide `IntegrateError` via the `From` impl
/// below, which flattens the error into a message string.
#[derive(Debug, Clone)]
pub enum DistributedError {
    /// Node communication failure
    CommunicationError(String),
    /// Node exceeded its heartbeat/communication timeout
    NodeTimeout(NodeId),
    /// Node failure with a human-readable reason
    NodeFailure(NodeId, String),
    /// Chunk processing error with a human-readable reason
    ChunkError(ChunkId, String),
    /// Synchronization error
    SyncError(String),
    /// Checkpoint error
    CheckpointError(String),
    /// Configuration error
    ConfigError(String),
    /// Resource exhaustion (memory, disk, handles, …)
    ResourceExhausted(String),
}
543
544impl std::fmt::Display for DistributedError {
545    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546        match self {
547            Self::CommunicationError(msg) => write!(f, "Communication error: {}", msg),
548            Self::NodeTimeout(id) => write!(f, "Node {} timed out", id),
549            Self::NodeFailure(id, msg) => write!(f, "Node {} failed: {}", id, msg),
550            Self::ChunkError(id, msg) => write!(f, "Chunk {:?} error: {}", id, msg),
551            Self::SyncError(msg) => write!(f, "Synchronization error: {}", msg),
552            Self::CheckpointError(msg) => write!(f, "Checkpoint error: {}", msg),
553            Self::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
554            Self::ResourceExhausted(msg) => write!(f, "Resource exhausted: {}", msg),
555        }
556    }
557}
558
559impl std::error::Error for DistributedError {}
560
impl From<DistributedError> for IntegrateError {
    /// Lossy conversion into the crate-wide error type: the distributed
    /// error is flattened into the message string of a `ComputationError`.
    fn from(err: DistributedError) -> Self {
        IntegrateError::ComputationError(err.to_string())
    }
}
566
/// Result type for distributed operations: `Ok(T)` or a `DistributedError`.
pub type DistributedResult<T> = std::result::Result<T, DistributedError>;
569
570#[cfg(test)]
571mod tests {
572    use super::*;
573    use std::net::{IpAddr, Ipv4Addr};
574
575    #[test]
576    fn test_node_id_display() {
577        let id = NodeId::new(42);
578        assert_eq!(format!("{}", id), "Node(42)");
579    }
580
    #[test]
    fn test_node_info_health_check() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let mut node = NodeInfo::new(NodeId::new(1), addr);
        node.status = NodeStatus::Available;

        // Freshly created node has a recent heartbeat, so it is healthy.
        assert!(node.is_healthy(Duration::from_secs(60)));

        // Simulate a heartbeat older than the timeout: node becomes unhealthy.
        node.last_heartbeat = Instant::now() - Duration::from_secs(120);
        assert!(!node.is_healthy(Duration::from_secs(60)));
    }
593
594    #[test]
595    fn test_work_chunk_retry() {
596        let chunk: WorkChunk<f64> =
597            WorkChunk::new(ChunkId::new(1), JobId::new(1), (0.0, 1.0), Array1::zeros(3));
598
599        assert!(chunk.can_retry());
600        let mut chunk = chunk;
601        for _ in 0..3 {
602            chunk.increment_retry();
603        }
604        assert!(!chunk.can_retry());
605    }
606
607    #[test]
608    fn test_distributed_metrics_load_balance() {
609        let mut metrics = DistributedMetrics::default();
610
611        // Perfect balance
612        metrics.update_load_balance(&[1.0, 1.0, 1.0, 1.0]);
613        assert!((metrics.load_balance_efficiency - 1.0).abs() < 0.01);
614
615        // Imbalanced
616        metrics.update_load_balance(&[0.1, 0.1, 0.1, 3.7]);
617        assert!(metrics.load_balance_efficiency < 0.5);
618    }
619}