// scirs2_integrate/distributed/solver.rs
1//! Distributed ODE solver implementation
2//!
3//! This module provides the main distributed ODE solver that coordinates
4//! work distribution across compute nodes for large-scale integration problems.
5
6use crate::common::IntegrateFloat;
7use crate::distributed::checkpointing::{
8    Checkpoint, CheckpointConfig, CheckpointGlobalState, CheckpointManager,
9    FaultToleranceCoordinator, RecoveryAction,
10};
11use crate::distributed::communication::{BoundaryExchanger, Communicator, MessageChannel};
12use crate::distributed::load_balancing::{ChunkDistributor, LoadBalancer, LoadBalancerConfig};
13use crate::distributed::node::{ComputeNode, NodeManager};
14use crate::distributed::types::{
15    BoundaryData, ChunkId, ChunkResult, ChunkResultStatus, DistributedConfig, DistributedError,
16    DistributedMetrics, DistributedResult, FaultToleranceMode, JobId, NodeId, NodeInfo, NodeStatus,
17    WorkChunk,
18};
19use crate::error::{IntegrateError, IntegrateResult};
20use crate::ode::types::{ODEMethod, ODEOptions};
21use scirs2_core::ndarray::{array, Array1, ArrayView1};
22use std::collections::{HashMap, VecDeque};
23use std::path::PathBuf;
24use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
25use std::sync::{Arc, Mutex, RwLock};
26use std::time::{Duration, Instant};
27
/// Distributed ODE solver
///
/// Coordinates chunked integration work across registered compute nodes:
/// the node manager tracks membership/health, the load balancer assigns
/// chunks, and the checkpoint manager / fault-tolerance coordinator
/// provide recovery support. All shared state is guarded by locks so the
/// solver can be driven from multiple threads behind an `Arc`.
pub struct DistributedODESolver<F: IntegrateFloat> {
    /// Node manager — registration, availability, and health monitoring
    node_manager: Arc<NodeManager>,
    /// Load balancer — assigns work chunks to nodes and records completions
    load_balancer: Arc<LoadBalancer<F>>,
    /// Checkpoint manager — periodic snapshots of job progress
    checkpoint_manager: Arc<CheckpointManager<F>>,
    /// Fault tolerance coordinator (built on top of the checkpoint manager)
    fault_coordinator: Arc<FaultToleranceCoordinator<F>>,
    /// Message channels, one per registered node
    channels: RwLock<HashMap<NodeId, Arc<MessageChannel<F>>>>,
    /// Boundary exchanger for inter-chunk boundary data
    boundary_exchanger: Arc<BoundaryExchanger<F>>,
    /// Solver configuration (captured at construction)
    config: DistributedConfig<F>,
    /// Next job ID (monotonically increasing, starts at 1)
    next_job_id: AtomicU64,
    /// Shutdown flag set by `shutdown()`
    shutdown: AtomicBool,
    /// Active jobs keyed by job ID
    active_jobs: RwLock<HashMap<JobId, JobState<F>>>,
    /// Aggregated metrics across solves
    metrics: Mutex<DistributedMetrics>,
}
53
/// State of an active job
///
/// Tracks the lifecycle of every chunk of a single solve: chunks move
/// from `pending_chunks` to `in_progress_chunks` when assigned, and into
/// `completed_chunks` when finished. `chunk_order` preserves the
/// time-ordering needed to chain chunk initial states and to assemble
/// the final trajectory.
struct JobState<F: IntegrateFloat> {
    /// Job ID
    job_id: JobId,
    /// Full integration interval (t_start, t_end)
    t_span: (F, F),
    /// Initial state vector y0 for the whole problem
    initial_state: Array1<F>,
    /// Total number of chunks the interval was split into
    total_chunks: usize,
    /// Results of completed chunks (unordered; sorted at assembly time)
    completed_chunks: Vec<ChunkResult<F>>,
    /// Chunks not yet assigned to a node
    pending_chunks: Vec<ChunkId>,
    /// Chunks currently assigned, mapped to their assigned node
    in_progress_chunks: HashMap<ChunkId, NodeId>,
    /// Time-ordering of chunk IDs, used for sequencing and assembly
    chunk_order: Vec<ChunkId>,
    /// Wall-clock start of the job
    start_time: Instant,
    /// Time of the most recent checkpoint, if any
    last_checkpoint: Option<Instant>,
    /// Chunks completed since the last checkpoint (drives checkpoint cadence)
    chunks_since_checkpoint: usize,
}
79
80impl<F: IntegrateFloat> DistributedODESolver<F> {
81    /// Create a new distributed ODE solver
82    pub fn new(config: DistributedConfig<F>) -> DistributedResult<Self> {
83        let node_manager = Arc::new(NodeManager::new(config.heartbeat_interval));
84
85        let load_balancer = Arc::new(LoadBalancer::new(
86            config.load_balancing,
87            LoadBalancerConfig::default(),
88        ));
89
90        let checkpoint_path = PathBuf::from("/tmp/scirs_checkpoints");
91        let checkpoint_config = CheckpointConfig {
92            persist_to_disk: config.checkpointing_enabled,
93            interval_chunks: config.checkpoint_interval,
94            ..Default::default()
95        };
96
97        let checkpoint_manager =
98            Arc::new(CheckpointManager::new(checkpoint_path, checkpoint_config)?);
99
100        let fault_coordinator = Arc::new(FaultToleranceCoordinator::new(
101            Arc::clone(&checkpoint_manager),
102            config.fault_tolerance,
103        ));
104
105        let boundary_exchanger = Arc::new(BoundaryExchanger::new(config.communication_timeout));
106
107        Ok(Self {
108            node_manager,
109            load_balancer,
110            checkpoint_manager,
111            fault_coordinator,
112            channels: RwLock::new(HashMap::new()),
113            boundary_exchanger,
114            config,
115            next_job_id: AtomicU64::new(1),
116            shutdown: AtomicBool::new(false),
117            active_jobs: RwLock::new(HashMap::new()),
118            metrics: Mutex::new(DistributedMetrics::default()),
119        })
120    }
121
122    /// Register a compute node
123    pub fn register_node(&self, node: NodeInfo) -> DistributedResult<()> {
124        let node_id = node.id;
125
126        // Register with node manager
127        self.node_manager
128            .register_node(node.address, node.capabilities.clone())?;
129
130        // Register with load balancer
131        self.load_balancer.register_node(node_id)?;
132
133        // Create message channel
134        let channel = Arc::new(MessageChannel::new(self.config.communication_timeout));
135        if let Ok(mut channels) = self.channels.write() {
136            channels.insert(node_id, channel);
137        }
138
139        Ok(())
140    }
141
142    /// Deregister a compute node
143    pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
144        self.node_manager.deregister_node(node_id)?;
145        self.load_balancer.deregister_node(node_id)?;
146
147        if let Ok(mut channels) = self.channels.write() {
148            channels.remove(&node_id);
149        }
150
151        Ok(())
152    }
153
154    /// Solve an ODE problem distributedly
155    pub fn solve<Func>(
156        &self,
157        f: Func,
158        t_span: (F, F),
159        y0: Array1<F>,
160        options: Option<ODEOptions<F>>,
161    ) -> IntegrateResult<DistributedODEResult<F>>
162    where
163        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
164    {
165        let start_time = Instant::now();
166
167        // Get available nodes
168        let available_nodes = self.node_manager.get_available_nodes();
169        if available_nodes.is_empty() {
170            return Err(IntegrateError::ComputationError(
171                "No compute nodes available".to_string(),
172            ));
173        }
174
175        // Create job
176        let job_id = JobId::new(self.next_job_id.fetch_add(1, Ordering::SeqCst));
177
178        // Calculate number of chunks based on nodes
179        let num_chunks = (available_nodes.len() * self.config.chunks_per_node).max(1);
180
181        // Create chunk distributor and generate chunks
182        let distributor = ChunkDistributor::new(job_id);
183        let chunks = distributor.create_chunks(t_span, y0.clone(), num_chunks);
184
185        // Initialize job state
186        let chunk_order: Vec<ChunkId> = chunks.iter().map(|c| c.id).collect();
187        let pending_chunks = chunk_order.clone();
188
189        let job_state = JobState {
190            job_id,
191            t_span,
192            initial_state: y0.clone(),
193            total_chunks: num_chunks,
194            completed_chunks: Vec::new(),
195            pending_chunks,
196            in_progress_chunks: HashMap::new(),
197            chunk_order,
198            start_time,
199            last_checkpoint: None,
200            chunks_since_checkpoint: 0,
201        };
202
203        // Register job
204        if let Ok(mut jobs) = self.active_jobs.write() {
205            jobs.insert(job_id, job_state);
206        }
207
208        // Distribute initial work
209        self.distribute_chunks(job_id, chunks, &available_nodes, &f)?;
210
211        // Wait for completion
212        let result = self.wait_for_completion(job_id, &f)?;
213
214        // Update metrics
215        if let Ok(mut metrics) = self.metrics.lock() {
216            metrics.total_processing_time += start_time.elapsed();
217        }
218
219        // Cleanup job
220        if let Ok(mut jobs) = self.active_jobs.write() {
221            jobs.remove(&job_id);
222        }
223
224        Ok(result)
225    }
226
227    /// Distribute chunks to nodes
228    fn distribute_chunks<Func>(
229        &self,
230        job_id: JobId,
231        chunks: Vec<WorkChunk<F>>,
232        nodes: &[NodeInfo],
233        f: &Func,
234    ) -> IntegrateResult<()>
235    where
236        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
237    {
238        for chunk in chunks {
239            let node_id = self
240                .load_balancer
241                .assign_chunk(&chunk, nodes)
242                .map_err(|e| IntegrateError::ComputationError(e.to_string()))?;
243
244            // Record assignment
245            if let Ok(mut jobs) = self.active_jobs.write() {
246                if let Some(job) = jobs.get_mut(&job_id) {
247                    job.pending_chunks.retain(|id| *id != chunk.id);
248                    job.in_progress_chunks.insert(chunk.id, node_id);
249                }
250            }
251
252            // In a real implementation, this would send the chunk over the network
253            // For now, we simulate local processing
254        }
255
256        Ok(())
257    }
258
    /// Wait for job completion
    ///
    /// Busy-polls the job state every 10 ms, driving the simulated local
    /// processing of in-progress chunks on each pass, until the job has
    /// no pending and no in-progress chunks. Gives up with a
    /// `ConvergenceError` after a hard-coded 1-hour deadline.
    ///
    /// NOTE(review): if a job ever had pending chunks but nothing
    /// in-progress, this loop would spin until the deadline. Today
    /// `distribute_chunks` moves every chunk to in-progress up front, so
    /// that state should be unreachable — confirm if distribution ever
    /// becomes partial or asynchronous.
    fn wait_for_completion<Func>(
        &self,
        job_id: JobId,
        f: &Func,
    ) -> IntegrateResult<DistributedODEResult<F>>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        let timeout = Duration::from_secs(3600); // 1 hour timeout
        let deadline = Instant::now() + timeout;

        loop {
            if Instant::now() > deadline {
                return Err(IntegrateError::ConvergenceError(
                    "Distributed solve timeout".to_string(),
                ));
            }

            // Check completion under a short-lived read lock; the lock is
            // released before any processing so writers are not blocked.
            let (is_complete, needs_processing) = {
                let jobs = self.active_jobs.read().map_err(|_| {
                    IntegrateError::ComputationError("Failed to read job state".to_string())
                })?;

                if let Some(job) = jobs.get(&job_id) {
                    let complete =
                        job.pending_chunks.is_empty() && job.in_progress_chunks.is_empty();
                    let needs = !job.in_progress_chunks.is_empty();
                    (complete, needs)
                } else {
                    return Err(IntegrateError::ComputationError(
                        "Job not found".to_string(),
                    ));
                }
            };

            if is_complete {
                break;
            }

            if needs_processing {
                // Simulate processing chunks
                self.process_pending_chunks(job_id, f)?;
            }

            std::thread::sleep(Duration::from_millis(10));
        }

        // Assemble result
        self.assemble_result(job_id)
    }
311
    /// Process pending chunks (simulation for local testing)
    ///
    /// Chunks are processed sequentially in chunk_order so that each chunk
    /// can use the final state from the previous chunk as its initial state.
    ///
    /// Locking discipline: each phase takes the `active_jobs` lock only
    /// briefly (read to snapshot, read to build a chunk, write to record
    /// the result) and releases it before the next phase, so no lock is
    /// held across chunk integration.
    fn process_pending_chunks<Func>(&self, job_id: JobId, f: &Func) -> IntegrateResult<()>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        // Get ordered list of in-progress chunk IDs and their assigned nodes.
        // Each entry carries the chunk's index within chunk_order so the
        // chunks can be processed in time order below.
        let ordered_chunks: Vec<(ChunkId, NodeId, usize)> = {
            let jobs = self.active_jobs.read().map_err(|_| {
                IntegrateError::ComputationError("Failed to read job state".to_string())
            })?;

            if let Some(job) = jobs.get(&job_id) {
                let mut items: Vec<(ChunkId, NodeId, usize)> = job
                    .in_progress_chunks
                    .iter()
                    .map(|(chunk_id, node_id)| {
                        // A chunk missing from chunk_order should not occur;
                        // fall back to index 0 rather than failing.
                        let idx = job
                            .chunk_order
                            .iter()
                            .position(|id| id == chunk_id)
                            .unwrap_or(0);
                        (*chunk_id, *node_id, idx)
                    })
                    .collect();
                // Sort by chunk order index so we process them sequentially
                items.sort_by_key(|&(_, _, idx)| idx);
                items
            } else {
                Vec::new()
            }
        };

        // Process each chunk one at a time, in order
        for (chunk_id, node_id, idx) in ordered_chunks {
            // Build the work chunk with correct initial state from completed chunks
            let chunk = {
                let jobs = self.active_jobs.read().map_err(|_| {
                    IntegrateError::ComputationError("Failed to read job state".to_string())
                })?;
                let job = jobs
                    .get(&job_id)
                    .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;

                // Recompute this chunk's time sub-interval from its index:
                // uniform width dt, with the last chunk pinned to t_end to
                // avoid accumulation error.
                let (t_start, t_end) = job.t_span;
                let dt = (t_end - t_start) / F::from(job.total_chunks).unwrap_or(F::one());

                let chunk_t_start = t_start + dt * F::from(idx).unwrap_or(F::zero());
                let chunk_t_end = if idx == job.total_chunks - 1 {
                    t_end
                } else {
                    t_start + dt * F::from(idx + 1).unwrap_or(F::one())
                };

                // Get initial state from previous chunk result or job initial state
                let initial_state = if idx == 0 {
                    job.initial_state.clone()
                } else {
                    let prev_chunk_id = job.chunk_order.get(idx - 1).ok_or_else(|| {
                        IntegrateError::ComputationError(
                            "Previous chunk not found in order".to_string(),
                        )
                    })?;
                    // If the predecessor hasn't completed yet, fall back to
                    // the job's initial state rather than failing.
                    job.completed_chunks
                        .iter()
                        .find(|r| r.chunk_id == *prev_chunk_id)
                        .map(|r| r.final_state.clone())
                        .unwrap_or_else(|| job.initial_state.clone())
                };

                WorkChunk::new(
                    chunk_id,
                    job_id,
                    (chunk_t_start, chunk_t_end),
                    initial_state,
                )
            };

            // Integrate the chunk locally (read lock released above).
            let result = self.process_single_chunk(&chunk, node_id, f)?;

            // Update job state: in-progress -> completed, and checkpoint
            // if the configured chunk interval has elapsed.
            if let Ok(mut jobs) = self.active_jobs.write() {
                if let Some(job) = jobs.get_mut(&job_id) {
                    job.in_progress_chunks.remove(&chunk_id);
                    job.completed_chunks.push(result);
                    job.chunks_since_checkpoint += 1;

                    // Check if checkpoint is needed
                    if self.config.checkpointing_enabled
                        && self
                            .checkpoint_manager
                            .should_checkpoint(job.chunks_since_checkpoint)
                    {
                        let global_state = CheckpointGlobalState {
                            iteration: 0,
                            chunks_completed: job.completed_chunks.len(),
                            chunks_remaining: job.pending_chunks.len()
                                + job.in_progress_chunks.len(),
                            current_time: F::zero(),
                            error_estimate: F::zero(),
                        };

                        // Checkpointing is best-effort; failure is ignored.
                        let _ = self.checkpoint_manager.create_checkpoint(
                            job_id,
                            job.completed_chunks.clone(),
                            job.in_progress_chunks.keys().cloned().collect(),
                            global_state,
                        );

                        job.chunks_since_checkpoint = 0;
                        job.last_checkpoint = Some(Instant::now());
                    }
                }
            }

            // Update load balancer
            let processing_time = Duration::from_millis(10); // Simulated
            self.load_balancer.report_completion(
                node_id,
                chunk.estimated_cost,
                processing_time,
                true,
            );
        }

        Ok(())
    }
441
442    /// Process a single chunk using local ODE solver
443    fn process_single_chunk<Func>(
444        &self,
445        chunk: &WorkChunk<F>,
446        node_id: NodeId,
447        f: &Func,
448    ) -> IntegrateResult<ChunkResult<F>>
449    where
450        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
451    {
452        let start_time = Instant::now();
453
454        // Use RK4 for simplicity
455        let (t_start, t_end) = chunk.time_interval;
456        let mut t = t_start;
457        let mut y = chunk.initial_state.clone();
458
459        let n_steps = 100;
460        let h = (t_end - t_start) / F::from(n_steps).unwrap_or(F::one());
461
462        let mut time_points = vec![t_start];
463        let mut states = vec![y.clone()];
464
465        for _ in 0..n_steps {
466            // RK4 step
467            let k1 = f(t, y.view());
468            let k2 = f(
469                t + h / F::from(2.0).unwrap_or(F::one()),
470                (&y + &(&k1 * h / F::from(2.0).unwrap_or(F::one()))).view(),
471            );
472            let k3 = f(
473                t + h / F::from(2.0).unwrap_or(F::one()),
474                (&y + &(&k2 * h / F::from(2.0).unwrap_or(F::one()))).view(),
475            );
476            let k4 = f(t + h, (&y + &(&k3 * h)).view());
477
478            y = &y
479                + &((&k1
480                    + &(&k2 * F::from(2.0).unwrap_or(F::one()))
481                    + &(&k3 * F::from(2.0).unwrap_or(F::one()))
482                    + &k4)
483                    * h
484                    / F::from(6.0).unwrap_or(F::one()));
485            t += h;
486
487            time_points.push(t);
488            states.push(y.clone());
489        }
490
491        let final_state = y.clone();
492        let final_derivative = Some(f(t, y.view()));
493
494        Ok(ChunkResult {
495            chunk_id: chunk.id,
496            node_id,
497            time_points,
498            states,
499            final_state,
500            final_derivative,
501            error_estimate: F::from(1e-6).unwrap_or(F::epsilon()),
502            processing_time: start_time.elapsed(),
503            memory_used: 0,
504            status: ChunkResultStatus::Success,
505        })
506    }
507
508    /// Assemble final result from completed chunks
509    fn assemble_result(&self, job_id: JobId) -> IntegrateResult<DistributedODEResult<F>> {
510        let jobs = self.active_jobs.read().map_err(|_| {
511            IntegrateError::ComputationError("Failed to read job state".to_string())
512        })?;
513
514        let job = jobs
515            .get(&job_id)
516            .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;
517
518        // Sort results by chunk order
519        let mut sorted_results: Vec<_> = job.completed_chunks.clone();
520        sorted_results.sort_by_key(|r| {
521            job.chunk_order
522                .iter()
523                .position(|id| *id == r.chunk_id)
524                .unwrap_or(usize::MAX)
525        });
526
527        // Concatenate time points and states
528        let mut t_all = Vec::new();
529        let mut y_all = Vec::new();
530
531        for (i, result) in sorted_results.iter().enumerate() {
532            let skip_first = if i > 0 { 1 } else { 0 };
533            t_all.extend(result.time_points.iter().skip(skip_first).cloned());
534            y_all.extend(result.states.iter().skip(skip_first).cloned());
535        }
536
537        let total_time = job.start_time.elapsed();
538
539        // Get metrics
540        let metrics = self.metrics.lock().map(|m| m.clone()).unwrap_or_default();
541
542        Ok(DistributedODEResult {
543            t: t_all,
544            y: y_all,
545            job_id,
546            chunks_processed: job.completed_chunks.len(),
547            nodes_used: job
548                .completed_chunks
549                .iter()
550                .map(|r| r.node_id)
551                .collect::<std::collections::HashSet<_>>()
552                .len(),
553            total_time,
554            metrics,
555        })
556    }
557
    /// Shutdown the solver
    ///
    /// Stores the cooperative shutdown flag and stops node health
    /// monitoring. NOTE(review): nothing in this file reads the flag, so
    /// in-flight jobs are not interrupted — confirm intended consumers.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::Relaxed);
        self.node_manager.stop_health_monitoring();
    }
563
564    /// Get solver metrics
565    pub fn get_metrics(&self) -> DistributedMetrics {
566        self.metrics.lock().map(|m| m.clone()).unwrap_or_default()
567    }
568}
569
/// Result of a distributed ODE solve
///
/// The concatenated trajectory of all processed chunks, in time order,
/// plus bookkeeping about the run. `t[i]` corresponds to `y[i]`.
#[derive(Debug, Clone)]
pub struct DistributedODEResult<F: IntegrateFloat> {
    /// Time points (ascending)
    pub t: Vec<F>,
    /// Solution states, one per time point
    pub y: Vec<Array1<F>>,
    /// Job ID
    pub job_id: JobId,
    /// Number of chunks processed
    pub chunks_processed: usize,
    /// Number of distinct nodes that contributed results
    pub nodes_used: usize,
    /// Total wall-clock time taken
    pub total_time: Duration,
    /// Distributed metrics snapshot at assembly time
    pub metrics: DistributedMetrics,
}
588
589impl<F: IntegrateFloat> DistributedODEResult<F> {
590    /// Get final state
591    pub fn final_state(&self) -> Option<&Array1<F>> {
592        self.y.last()
593    }
594
595    /// Get state at specific index
596    pub fn state_at(&self, index: usize) -> Option<&Array1<F>> {
597        self.y.get(index)
598    }
599
600    /// Get number of time points
601    pub fn len(&self) -> usize {
602        self.t.len()
603    }
604
605    /// Check if result is empty
606    pub fn is_empty(&self) -> bool {
607        self.t.is_empty()
608    }
609
610    /// Interpolate solution at a given time
611    pub fn interpolate(&self, t_target: F) -> Option<Array1<F>> {
612        if self.t.is_empty() {
613            return None;
614        }
615
616        // Find bracketing points
617        let mut left_idx = 0;
618        for (i, &t) in self.t.iter().enumerate() {
619            if t <= t_target {
620                left_idx = i;
621            } else {
622                break;
623            }
624        }
625
626        let right_idx = (left_idx + 1).min(self.t.len() - 1);
627
628        if left_idx == right_idx {
629            return self.y.get(left_idx).cloned();
630        }
631
632        // Linear interpolation
633        let t_left = self.t[left_idx];
634        let t_right = self.t[right_idx];
635        let dt = t_right - t_left;
636
637        if dt.abs() < F::epsilon() {
638            return self.y.get(left_idx).cloned();
639        }
640
641        let alpha = (t_target - t_left) / dt;
642        let y_left = &self.y[left_idx];
643        let y_right = &self.y[right_idx];
644
645        Some(y_left * (F::one() - alpha) + y_right * alpha)
646    }
647}
648
/// Builder for distributed ODE solver
///
/// Accumulates a `DistributedConfig` through chained setters; `build()`
/// hands the configuration to `DistributedODESolver::new`.
pub struct DistributedODESolverBuilder<F: IntegrateFloat> {
    /// Configuration under construction
    config: DistributedConfig<F>,
}
653
654impl<F: IntegrateFloat> DistributedODESolverBuilder<F> {
655    /// Create a new builder with default configuration
656    pub fn new() -> Self {
657        Self {
658            config: DistributedConfig::default(),
659        }
660    }
661
662    /// Set tolerance
663    pub fn tolerance(mut self, tol: F) -> Self {
664        self.config.tolerance = tol;
665        self
666    }
667
668    /// Set chunks per node
669    pub fn chunks_per_node(mut self, n: usize) -> Self {
670        self.config.chunks_per_node = n;
671        self
672    }
673
674    /// Enable checkpointing
675    pub fn with_checkpointing(mut self, interval: usize) -> Self {
676        self.config.checkpointing_enabled = true;
677        self.config.checkpoint_interval = interval;
678        self
679    }
680
681    /// Set fault tolerance mode
682    pub fn fault_tolerance(mut self, mode: FaultToleranceMode) -> Self {
683        self.config.fault_tolerance = mode;
684        self
685    }
686
687    /// Set communication timeout
688    pub fn timeout(mut self, timeout: Duration) -> Self {
689        self.config.communication_timeout = timeout;
690        self
691    }
692
693    /// Build the solver
694    pub fn build(self) -> DistributedResult<DistributedODESolver<F>> {
695        DistributedODESolver::new(self.config)
696    }
697}
698
699impl<F: IntegrateFloat> Default for DistributedODESolverBuilder<F> {
700    fn default() -> Self {
701        Self::new()
702    }
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::types::NodeCapabilities;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Build an Available loopback test node; the port is offset by `id`
    /// so multiple nodes get distinct addresses.
    fn create_test_node(id: u64) -> NodeInfo {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080 + id as u16);
        let mut info = NodeInfo::new(NodeId::new(id), addr);
        info.capabilities = NodeCapabilities::default();
        info.status = NodeStatus::Available;
        info
    }

    /// Constructing a solver from the default config succeeds.
    #[test]
    fn test_distributed_solver_creation() {
        let config = DistributedConfig::<f64>::default();
        let solver = DistributedODESolver::new(config);
        assert!(solver.is_ok());
    }

    /// Registering a single node succeeds.
    #[test]
    fn test_distributed_solver_node_registration() {
        let config = DistributedConfig::<f64>::default();
        let solver = DistributedODESolver::new(config).expect("Failed to create solver");

        let node = create_test_node(1);
        let result = solver.register_node(node);
        assert!(result.is_ok());
    }

    /// End-to-end solve of y' = -y over [0, 1] with two simulated nodes;
    /// checks the final value against the analytic solution e^(-1).
    #[test]
    fn test_distributed_solve_simple_ode() {
        let config = DistributedConfig::<f64>::default();
        let solver = DistributedODESolver::new(config).expect("Failed to create solver");

        // Register test nodes
        for i in 0..2 {
            let node = create_test_node(i);
            solver.register_node(node).expect("Failed to register node");
        }

        // Solve y' = -y, y(0) = 1
        let f = |_t: f64, y: ArrayView1<f64>| array![-y[0]];
        let y0 = array![1.0];

        let result = solver.solve(f, (0.0, 1.0), y0, None);
        assert!(result.is_ok());

        let result = result.expect("Solve failed");
        assert!(!result.t.is_empty());
        assert!(!result.y.is_empty());

        // Final value should be close to e^(-1)
        let expected = (-1.0_f64).exp();
        let actual = result.final_state().expect("No final state")[0];
        assert!((actual - expected).abs() < 0.01);
    }

    /// Linear interpolation between t=0 (y=1.0) and t=0.5 (y=0.6) at
    /// t=0.25 should give the midpoint value 0.8.
    #[test]
    fn test_distributed_result_interpolation() {
        let result = DistributedODEResult::<f64> {
            t: vec![0.0, 0.5, 1.0],
            y: vec![array![1.0], array![0.6], array![0.4]],
            job_id: JobId::new(1),
            chunks_processed: 1,
            nodes_used: 1,
            total_time: Duration::from_secs(1),
            metrics: DistributedMetrics::default(),
        };

        let interpolated = result.interpolate(0.25).expect("Interpolation failed");
        assert!((interpolated[0] - 0.8_f64).abs() < 0.01_f64);
    }

    /// The builder's chained setters produce a solver successfully.
    #[test]
    fn test_solver_builder() {
        let solver = DistributedODESolverBuilder::<f64>::new()
            .tolerance(1e-8)
            .chunks_per_node(8)
            .with_checkpointing(5)
            .fault_tolerance(FaultToleranceMode::Standard)
            .timeout(Duration::from_secs(60))
            .build();

        assert!(solver.is_ok());
    }
}