1use crate::common::IntegrateFloat;
7use crate::distributed::checkpointing::{
8 Checkpoint, CheckpointConfig, CheckpointGlobalState, CheckpointManager,
9 FaultToleranceCoordinator, RecoveryAction,
10};
11use crate::distributed::communication::{BoundaryExchanger, Communicator, MessageChannel};
12use crate::distributed::load_balancing::{ChunkDistributor, LoadBalancer, LoadBalancerConfig};
13use crate::distributed::node::{ComputeNode, NodeManager};
14use crate::distributed::types::{
15 BoundaryData, ChunkId, ChunkResult, ChunkResultStatus, DistributedConfig, DistributedError,
16 DistributedMetrics, DistributedResult, FaultToleranceMode, JobId, NodeId, NodeInfo, NodeStatus,
17 WorkChunk,
18};
19use crate::error::{IntegrateError, IntegrateResult};
20use crate::ode::types::{ODEMethod, ODEOptions};
21use scirs2_core::ndarray::{array, Array1, ArrayView1};
22use std::collections::{HashMap, VecDeque};
23use std::path::PathBuf;
24use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
25use std::sync::{Arc, Mutex, RwLock};
26use std::time::{Duration, Instant};
27
/// Splits an ODE initial-value problem into time chunks, assigns each chunk to a registered
/// compute node, integrates the chunks, and stitches the pieces into one trajectory.
///
/// NOTE(review): in this file chunks are integrated synchronously on the calling thread
/// (`process_pending_chunks` calls `process_single_chunk` directly); the node a chunk is
/// assigned to is only recorded in the result and reported to the load balancer.
pub struct DistributedODESolver<F: IntegrateFloat> {
    /// Registry of compute nodes; queried for available nodes at the start of each solve.
    node_manager: Arc<NodeManager>,
    /// Chooses a node for each chunk and receives per-chunk completion reports.
    load_balancer: Arc<LoadBalancer<F>>,
    /// Decides when a checkpoint is due and stores checkpoints of completed chunk results.
    checkpoint_manager: Arc<CheckpointManager<F>>,
    /// Fault-tolerance coordinator sharing `checkpoint_manager`.
    /// NOTE(review): constructed in `new` but not otherwise used in this file.
    fault_coordinator: Arc<FaultToleranceCoordinator<F>>,
    /// Per-node message channels, created in `register_node` and removed in `deregister_node`.
    channels: RwLock<HashMap<NodeId, Arc<MessageChannel<F>>>>,
    /// Boundary-data exchanger.
    /// NOTE(review): constructed in `new` but not otherwise used in this file.
    boundary_exchanger: Arc<BoundaryExchanger<F>>,
    /// Configuration the solver was built with.
    config: DistributedConfig<F>,
    /// Source of job ids; starts at 1 and is incremented once per `solve` call.
    next_job_id: AtomicU64,
    /// Flag set by `shutdown()`.
    shutdown: AtomicBool,
    /// Book-keeping for solves currently in flight, keyed by job id.
    active_jobs: RwLock<HashMap<JobId, JobState<F>>>,
    /// Solver-wide metrics; `solve` adds each successful job's wall-clock time to them.
    metrics: Mutex<DistributedMetrics>,
}
53
/// Book-keeping for one in-flight `solve` call.
struct JobState<F: IntegrateFloat> {
    /// Identifier of this job.
    /// NOTE(review): not read anywhere in this file; the map key carries the same id.
    job_id: JobId,
    /// Overall integration interval `(t_start, t_end)`.
    t_span: (F, F),
    /// State at `t_span.0`; the initial condition of the first chunk.
    initial_state: Array1<F>,
    /// Number of chunks the time span was divided into; used to compute each chunk's
    /// sub-interval.
    total_chunks: usize,
    /// Results of finished chunks, in completion order (re-sorted when assembling).
    completed_chunks: Vec<ChunkResult<F>>,
    /// Chunks that have been created but not yet assigned to a node.
    pending_chunks: Vec<ChunkId>,
    /// Chunks assigned to a node and not yet completed, with their assigned node.
    in_progress_chunks: HashMap<ChunkId, NodeId>,
    /// Chunk ids in time order; a chunk's position here fixes its sub-interval and its
    /// predecessor.
    chunk_order: Vec<ChunkId>,
    /// When the job was created; used to report `total_time`.
    start_time: Instant,
    /// Time of the most recent checkpoint, if any (written, not read, in this file).
    last_checkpoint: Option<Instant>,
    /// Chunks completed since the last checkpoint (or since job start).
    chunks_since_checkpoint: usize,
}
79
80impl<F: IntegrateFloat> DistributedODESolver<F> {
81 pub fn new(config: DistributedConfig<F>) -> DistributedResult<Self> {
83 let node_manager = Arc::new(NodeManager::new(config.heartbeat_interval));
84
85 let load_balancer = Arc::new(LoadBalancer::new(
86 config.load_balancing,
87 LoadBalancerConfig::default(),
88 ));
89
90 let checkpoint_path = {
91 let mut p = std::env::temp_dir();
92 p.push("scirs_checkpoints");
93 p
94 };
95 let checkpoint_config = CheckpointConfig {
96 persist_to_disk: config.checkpointing_enabled,
97 interval_chunks: config.checkpoint_interval,
98 ..Default::default()
99 };
100
101 let checkpoint_manager =
102 Arc::new(CheckpointManager::new(checkpoint_path, checkpoint_config)?);
103
104 let fault_coordinator = Arc::new(FaultToleranceCoordinator::new(
105 Arc::clone(&checkpoint_manager),
106 config.fault_tolerance,
107 ));
108
109 let boundary_exchanger = Arc::new(BoundaryExchanger::new(config.communication_timeout));
110
111 Ok(Self {
112 node_manager,
113 load_balancer,
114 checkpoint_manager,
115 fault_coordinator,
116 channels: RwLock::new(HashMap::new()),
117 boundary_exchanger,
118 config,
119 next_job_id: AtomicU64::new(1),
120 shutdown: AtomicBool::new(false),
121 active_jobs: RwLock::new(HashMap::new()),
122 metrics: Mutex::new(DistributedMetrics::default()),
123 })
124 }
125
126 pub fn register_node(&self, node: NodeInfo) -> DistributedResult<()> {
128 let node_id = node.id;
129
130 self.node_manager
132 .register_node(node.address, node.capabilities.clone())?;
133
134 self.load_balancer.register_node(node_id)?;
136
137 let channel = Arc::new(MessageChannel::new(self.config.communication_timeout));
139 if let Ok(mut channels) = self.channels.write() {
140 channels.insert(node_id, channel);
141 }
142
143 Ok(())
144 }
145
146 pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
148 self.node_manager.deregister_node(node_id)?;
149 self.load_balancer.deregister_node(node_id)?;
150
151 if let Ok(mut channels) = self.channels.write() {
152 channels.remove(&node_id);
153 }
154
155 Ok(())
156 }
157
158 pub fn solve<Func>(
160 &self,
161 f: Func,
162 t_span: (F, F),
163 y0: Array1<F>,
164 options: Option<ODEOptions<F>>,
165 ) -> IntegrateResult<DistributedODEResult<F>>
166 where
167 Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
168 {
169 let start_time = Instant::now();
170
171 let available_nodes = self.node_manager.get_available_nodes();
173 if available_nodes.is_empty() {
174 return Err(IntegrateError::ComputationError(
175 "No compute nodes available".to_string(),
176 ));
177 }
178
179 let job_id = JobId::new(self.next_job_id.fetch_add(1, Ordering::SeqCst));
181
182 let num_chunks = (available_nodes.len() * self.config.chunks_per_node).max(1);
184
185 let distributor = ChunkDistributor::new(job_id);
187 let chunks = distributor.create_chunks(t_span, y0.clone(), num_chunks);
188
189 let chunk_order: Vec<ChunkId> = chunks.iter().map(|c| c.id).collect();
191 let pending_chunks = chunk_order.clone();
192
193 let job_state = JobState {
194 job_id,
195 t_span,
196 initial_state: y0.clone(),
197 total_chunks: num_chunks,
198 completed_chunks: Vec::new(),
199 pending_chunks,
200 in_progress_chunks: HashMap::new(),
201 chunk_order,
202 start_time,
203 last_checkpoint: None,
204 chunks_since_checkpoint: 0,
205 };
206
207 if let Ok(mut jobs) = self.active_jobs.write() {
209 jobs.insert(job_id, job_state);
210 }
211
212 self.distribute_chunks(job_id, chunks, &available_nodes, &f)?;
214
215 let result = self.wait_for_completion(job_id, &f)?;
217
218 if let Ok(mut metrics) = self.metrics.lock() {
220 metrics.total_processing_time += start_time.elapsed();
221 }
222
223 if let Ok(mut jobs) = self.active_jobs.write() {
225 jobs.remove(&job_id);
226 }
227
228 Ok(result)
229 }
230
231 fn distribute_chunks<Func>(
233 &self,
234 job_id: JobId,
235 chunks: Vec<WorkChunk<F>>,
236 nodes: &[NodeInfo],
237 f: &Func,
238 ) -> IntegrateResult<()>
239 where
240 Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
241 {
242 for chunk in chunks {
243 let node_id = self
244 .load_balancer
245 .assign_chunk(&chunk, nodes)
246 .map_err(|e| IntegrateError::ComputationError(e.to_string()))?;
247
248 if let Ok(mut jobs) = self.active_jobs.write() {
250 if let Some(job) = jobs.get_mut(&job_id) {
251 job.pending_chunks.retain(|id| *id != chunk.id);
252 job.in_progress_chunks.insert(chunk.id, node_id);
253 }
254 }
255
256 }
259
260 Ok(())
261 }
262
263 fn wait_for_completion<Func>(
265 &self,
266 job_id: JobId,
267 f: &Func,
268 ) -> IntegrateResult<DistributedODEResult<F>>
269 where
270 Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
271 {
272 let timeout = Duration::from_secs(3600); let deadline = Instant::now() + timeout;
274
275 loop {
276 if Instant::now() > deadline {
277 return Err(IntegrateError::ConvergenceError(
278 "Distributed solve timeout".to_string(),
279 ));
280 }
281
282 let (is_complete, needs_processing) = {
284 let jobs = self.active_jobs.read().map_err(|_| {
285 IntegrateError::ComputationError("Failed to read job state".to_string())
286 })?;
287
288 if let Some(job) = jobs.get(&job_id) {
289 let complete =
290 job.pending_chunks.is_empty() && job.in_progress_chunks.is_empty();
291 let needs = !job.in_progress_chunks.is_empty();
292 (complete, needs)
293 } else {
294 return Err(IntegrateError::ComputationError(
295 "Job not found".to_string(),
296 ));
297 }
298 };
299
300 if is_complete {
301 break;
302 }
303
304 if needs_processing {
305 self.process_pending_chunks(job_id, f)?;
307 }
308
309 std::thread::sleep(Duration::from_millis(10));
310 }
311
312 self.assemble_result(job_id)
314 }
315
316 fn process_pending_chunks<Func>(&self, job_id: JobId, f: &Func) -> IntegrateResult<()>
321 where
322 Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
323 {
324 let ordered_chunks: Vec<(ChunkId, NodeId, usize)> = {
326 let jobs = self.active_jobs.read().map_err(|_| {
327 IntegrateError::ComputationError("Failed to read job state".to_string())
328 })?;
329
330 if let Some(job) = jobs.get(&job_id) {
331 let mut items: Vec<(ChunkId, NodeId, usize)> = job
332 .in_progress_chunks
333 .iter()
334 .map(|(chunk_id, node_id)| {
335 let idx = job
336 .chunk_order
337 .iter()
338 .position(|id| id == chunk_id)
339 .unwrap_or(0);
340 (*chunk_id, *node_id, idx)
341 })
342 .collect();
343 items.sort_by_key(|&(_, _, idx)| idx);
345 items
346 } else {
347 Vec::new()
348 }
349 };
350
351 for (chunk_id, node_id, idx) in ordered_chunks {
353 let chunk = {
355 let jobs = self.active_jobs.read().map_err(|_| {
356 IntegrateError::ComputationError("Failed to read job state".to_string())
357 })?;
358 let job = jobs
359 .get(&job_id)
360 .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;
361
362 let (t_start, t_end) = job.t_span;
363 let dt = (t_end - t_start) / F::from(job.total_chunks).unwrap_or(F::one());
364
365 let chunk_t_start = t_start + dt * F::from(idx).unwrap_or(F::zero());
366 let chunk_t_end = if idx == job.total_chunks - 1 {
367 t_end
368 } else {
369 t_start + dt * F::from(idx + 1).unwrap_or(F::one())
370 };
371
372 let initial_state = if idx == 0 {
374 job.initial_state.clone()
375 } else {
376 let prev_chunk_id = job.chunk_order.get(idx - 1).ok_or_else(|| {
377 IntegrateError::ComputationError(
378 "Previous chunk not found in order".to_string(),
379 )
380 })?;
381 job.completed_chunks
382 .iter()
383 .find(|r| r.chunk_id == *prev_chunk_id)
384 .map(|r| r.final_state.clone())
385 .unwrap_or_else(|| job.initial_state.clone())
386 };
387
388 WorkChunk::new(
389 chunk_id,
390 job_id,
391 (chunk_t_start, chunk_t_end),
392 initial_state,
393 )
394 };
395
396 let result = self.process_single_chunk(&chunk, node_id, f)?;
397
398 if let Ok(mut jobs) = self.active_jobs.write() {
400 if let Some(job) = jobs.get_mut(&job_id) {
401 job.in_progress_chunks.remove(&chunk_id);
402 job.completed_chunks.push(result);
403 job.chunks_since_checkpoint += 1;
404
405 if self.config.checkpointing_enabled
407 && self
408 .checkpoint_manager
409 .should_checkpoint(job.chunks_since_checkpoint)
410 {
411 let global_state = CheckpointGlobalState {
412 iteration: 0,
413 chunks_completed: job.completed_chunks.len(),
414 chunks_remaining: job.pending_chunks.len()
415 + job.in_progress_chunks.len(),
416 current_time: F::zero(),
417 error_estimate: F::zero(),
418 };
419
420 let _ = self.checkpoint_manager.create_checkpoint(
421 job_id,
422 job.completed_chunks.clone(),
423 job.in_progress_chunks.keys().cloned().collect(),
424 global_state,
425 );
426
427 job.chunks_since_checkpoint = 0;
428 job.last_checkpoint = Some(Instant::now());
429 }
430 }
431 }
432
433 let processing_time = Duration::from_millis(10); self.load_balancer.report_completion(
436 node_id,
437 chunk.estimated_cost,
438 processing_time,
439 true,
440 );
441 }
442
443 Ok(())
444 }
445
446 fn process_single_chunk<Func>(
448 &self,
449 chunk: &WorkChunk<F>,
450 node_id: NodeId,
451 f: &Func,
452 ) -> IntegrateResult<ChunkResult<F>>
453 where
454 Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
455 {
456 let start_time = Instant::now();
457
458 let (t_start, t_end) = chunk.time_interval;
460 let mut t = t_start;
461 let mut y = chunk.initial_state.clone();
462
463 let n_steps = 100;
464 let h = (t_end - t_start) / F::from(n_steps).unwrap_or(F::one());
465
466 let mut time_points = vec![t_start];
467 let mut states = vec![y.clone()];
468
469 for _ in 0..n_steps {
470 let k1 = f(t, y.view());
472 let k2 = f(
473 t + h / F::from(2.0).unwrap_or(F::one()),
474 (&y + &(&k1 * h / F::from(2.0).unwrap_or(F::one()))).view(),
475 );
476 let k3 = f(
477 t + h / F::from(2.0).unwrap_or(F::one()),
478 (&y + &(&k2 * h / F::from(2.0).unwrap_or(F::one()))).view(),
479 );
480 let k4 = f(t + h, (&y + &(&k3 * h)).view());
481
482 y = &y
483 + &((&k1
484 + &(&k2 * F::from(2.0).unwrap_or(F::one()))
485 + &(&k3 * F::from(2.0).unwrap_or(F::one()))
486 + &k4)
487 * h
488 / F::from(6.0).unwrap_or(F::one()));
489 t += h;
490
491 time_points.push(t);
492 states.push(y.clone());
493 }
494
495 let final_state = y.clone();
496 let final_derivative = Some(f(t, y.view()));
497
498 Ok(ChunkResult {
499 chunk_id: chunk.id,
500 node_id,
501 time_points,
502 states,
503 final_state,
504 final_derivative,
505 error_estimate: F::from(1e-6).unwrap_or(F::epsilon()),
506 processing_time: start_time.elapsed(),
507 memory_used: 0,
508 status: ChunkResultStatus::Success,
509 })
510 }
511
512 fn assemble_result(&self, job_id: JobId) -> IntegrateResult<DistributedODEResult<F>> {
514 let jobs = self.active_jobs.read().map_err(|_| {
515 IntegrateError::ComputationError("Failed to read job state".to_string())
516 })?;
517
518 let job = jobs
519 .get(&job_id)
520 .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;
521
522 let mut sorted_results: Vec<_> = job.completed_chunks.clone();
524 sorted_results.sort_by_key(|r| {
525 job.chunk_order
526 .iter()
527 .position(|id| *id == r.chunk_id)
528 .unwrap_or(usize::MAX)
529 });
530
531 let mut t_all = Vec::new();
533 let mut y_all = Vec::new();
534
535 for (i, result) in sorted_results.iter().enumerate() {
536 let skip_first = if i > 0 { 1 } else { 0 };
537 t_all.extend(result.time_points.iter().skip(skip_first).cloned());
538 y_all.extend(result.states.iter().skip(skip_first).cloned());
539 }
540
541 let total_time = job.start_time.elapsed();
542
543 let metrics = self.metrics.lock().map(|m| m.clone()).unwrap_or_default();
545
546 Ok(DistributedODEResult {
547 t: t_all,
548 y: y_all,
549 job_id,
550 chunks_processed: job.completed_chunks.len(),
551 nodes_used: job
552 .completed_chunks
553 .iter()
554 .map(|r| r.node_id)
555 .collect::<std::collections::HashSet<_>>()
556 .len(),
557 total_time,
558 metrics,
559 })
560 }
561
562 pub fn shutdown(&self) {
564 self.shutdown.store(true, Ordering::Relaxed);
565 self.node_manager.stop_health_monitoring();
566 }
567
568 pub fn get_metrics(&self) -> DistributedMetrics {
570 self.metrics.lock().map(|m| m.clone()).unwrap_or_default()
571 }
572}
573
/// Output of a distributed solve: the stitched trajectory plus job statistics.
#[derive(Debug, Clone)]
pub struct DistributedODEResult<F: IntegrateFloat> {
    /// Time points of the trajectory, in chunk order. When stitching, the first point of every
    /// chunk after the first is dropped.
    pub t: Vec<F>,
    /// State vectors, one per entry of `t`.
    pub y: Vec<Array1<F>>,
    /// Identifier of the job that produced this result.
    pub job_id: JobId,
    /// Number of chunk results that were assembled.
    pub chunks_processed: usize,
    /// Number of distinct node ids among the chunk results.
    pub nodes_used: usize,
    /// Wall-clock time from job creation until the result was assembled.
    pub total_time: Duration,
    /// Snapshot of the solver-wide metrics taken when the result was produced.
    pub metrics: DistributedMetrics,
}
592
593impl<F: IntegrateFloat> DistributedODEResult<F> {
594 pub fn final_state(&self) -> Option<&Array1<F>> {
596 self.y.last()
597 }
598
599 pub fn state_at(&self, index: usize) -> Option<&Array1<F>> {
601 self.y.get(index)
602 }
603
604 pub fn len(&self) -> usize {
606 self.t.len()
607 }
608
609 pub fn is_empty(&self) -> bool {
611 self.t.is_empty()
612 }
613
614 pub fn interpolate(&self, t_target: F) -> Option<Array1<F>> {
616 if self.t.is_empty() {
617 return None;
618 }
619
620 let mut left_idx = 0;
622 for (i, &t) in self.t.iter().enumerate() {
623 if t <= t_target {
624 left_idx = i;
625 } else {
626 break;
627 }
628 }
629
630 let right_idx = (left_idx + 1).min(self.t.len() - 1);
631
632 if left_idx == right_idx {
633 return self.y.get(left_idx).cloned();
634 }
635
636 let t_left = self.t[left_idx];
638 let t_right = self.t[right_idx];
639 let dt = t_right - t_left;
640
641 if dt.abs() < F::epsilon() {
642 return self.y.get(left_idx).cloned();
643 }
644
645 let alpha = (t_target - t_left) / dt;
646 let y_left = &self.y[left_idx];
647 let y_right = &self.y[right_idx];
648
649 Some(y_left * (F::one() - alpha) + y_right * alpha)
650 }
651}
652
/// Fluent builder for [`DistributedODESolver`], starting from `DistributedConfig::default()`.
pub struct DistributedODESolverBuilder<F: IntegrateFloat> {
    /// Configuration accumulated by the builder methods and handed to the solver on `build`.
    config: DistributedConfig<F>,
}
657
658impl<F: IntegrateFloat> DistributedODESolverBuilder<F> {
659 pub fn new() -> Self {
661 Self {
662 config: DistributedConfig::default(),
663 }
664 }
665
666 pub fn tolerance(mut self, tol: F) -> Self {
668 self.config.tolerance = tol;
669 self
670 }
671
672 pub fn chunks_per_node(mut self, n: usize) -> Self {
674 self.config.chunks_per_node = n;
675 self
676 }
677
678 pub fn with_checkpointing(mut self, interval: usize) -> Self {
680 self.config.checkpointing_enabled = true;
681 self.config.checkpoint_interval = interval;
682 self
683 }
684
685 pub fn fault_tolerance(mut self, mode: FaultToleranceMode) -> Self {
687 self.config.fault_tolerance = mode;
688 self
689 }
690
691 pub fn timeout(mut self, timeout: Duration) -> Self {
693 self.config.communication_timeout = timeout;
694 self
695 }
696
697 pub fn build(self) -> DistributedResult<DistributedODESolver<F>> {
699 DistributedODESolver::new(self.config)
700 }
701}
702
703impl<F: IntegrateFloat> Default for DistributedODESolverBuilder<F> {
704 fn default() -> Self {
705 Self::new()
706 }
707}
708
709#[cfg(test)]
710mod tests {
711 use super::*;
712 use crate::distributed::types::NodeCapabilities;
713 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
714
715 fn create_test_node(id: u64) -> NodeInfo {
716 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080 + id as u16);
717 let mut info = NodeInfo::new(NodeId::new(id), addr);
718 info.capabilities = NodeCapabilities::default();
719 info.status = NodeStatus::Available;
720 info
721 }
722
723 #[test]
724 fn test_distributed_solver_creation() {
725 let config = DistributedConfig::<f64>::default();
726 let solver = DistributedODESolver::new(config);
727 assert!(solver.is_ok());
728 }
729
730 #[test]
731 fn test_distributed_solver_node_registration() {
732 let config = DistributedConfig::<f64>::default();
733 let solver = DistributedODESolver::new(config).expect("Failed to create solver");
734
735 let node = create_test_node(1);
736 let result = solver.register_node(node);
737 assert!(result.is_ok());
738 }
739
740 #[test]
741 fn test_distributed_solve_simple_ode() {
742 let config = DistributedConfig::<f64>::default();
743 let solver = DistributedODESolver::new(config).expect("Failed to create solver");
744
745 for i in 0..2 {
747 let node = create_test_node(i);
748 solver.register_node(node).expect("Failed to register node");
749 }
750
751 let f = |_t: f64, y: ArrayView1<f64>| array![-y[0]];
753 let y0 = array![1.0];
754
755 let result = solver.solve(f, (0.0, 1.0), y0, None);
756 assert!(result.is_ok());
757
758 let result = result.expect("Solve failed");
759 assert!(!result.t.is_empty());
760 assert!(!result.y.is_empty());
761
762 let expected = (-1.0_f64).exp();
764 let actual = result.final_state().expect("No final state")[0];
765 assert!((actual - expected).abs() < 0.01);
766 }
767
768 #[test]
769 fn test_distributed_result_interpolation() {
770 let result = DistributedODEResult::<f64> {
771 t: vec![0.0, 0.5, 1.0],
772 y: vec![array![1.0], array![0.6], array![0.4]],
773 job_id: JobId::new(1),
774 chunks_processed: 1,
775 nodes_used: 1,
776 total_time: Duration::from_secs(1),
777 metrics: DistributedMetrics::default(),
778 };
779
780 let interpolated = result.interpolate(0.25).expect("Interpolation failed");
781 assert!((interpolated[0] - 0.8_f64).abs() < 0.01_f64);
782 }
783
784 #[test]
785 fn test_solver_builder() {
786 let solver = DistributedODESolverBuilder::<f64>::new()
787 .tolerance(1e-8)
788 .chunks_per_node(8)
789 .with_checkpointing(5)
790 .fault_tolerance(FaultToleranceMode::Standard)
791 .timeout(Duration::from_secs(60))
792 .build();
793
794 assert!(solver.is_ok());
795 }
796}