1use crate::common::IntegrateFloat;
7use crate::distributed::checkpointing::{
8 Checkpoint, CheckpointConfig, CheckpointGlobalState, CheckpointManager,
9 FaultToleranceCoordinator, RecoveryAction,
10};
11use crate::distributed::communication::{BoundaryExchanger, Communicator, MessageChannel};
12use crate::distributed::load_balancing::{ChunkDistributor, LoadBalancer, LoadBalancerConfig};
13use crate::distributed::node::{ComputeNode, NodeManager};
14use crate::distributed::types::{
15 BoundaryData, ChunkId, ChunkResult, ChunkResultStatus, DistributedConfig, DistributedError,
16 DistributedMetrics, DistributedResult, FaultToleranceMode, JobId, NodeId, NodeInfo, NodeStatus,
17 WorkChunk,
18};
19use crate::error::{IntegrateError, IntegrateResult};
20use crate::ode::types::{ODEMethod, ODEOptions};
21use scirs2_core::ndarray::{array, Array1, ArrayView1};
22use std::collections::{HashMap, VecDeque};
23use std::path::PathBuf;
24use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
25use std::sync::{Arc, Mutex, RwLock};
26use std::time::{Duration, Instant};
27
/// Coordinator for solving ODE initial-value problems across multiple
/// compute nodes: it partitions the time span into work chunks, assigns
/// them through the load balancer, and stitches the per-chunk results
/// back into a single trajectory.
pub struct DistributedODESolver<F: IntegrateFloat> {
    /// Tracks registered compute nodes and their health.
    node_manager: Arc<NodeManager>,
    /// Assigns work chunks to nodes according to the configured policy.
    load_balancer: Arc<LoadBalancer<F>>,
    /// Creates and persists checkpoints of in-flight jobs.
    checkpoint_manager: Arc<CheckpointManager<F>>,
    /// Fault-tolerance coordination built on the checkpoint manager.
    /// NOTE(review): not exercised by the code visible in this file.
    fault_coordinator: Arc<FaultToleranceCoordinator<F>>,
    /// Per-node message channels, keyed by node id.
    channels: RwLock<HashMap<NodeId, Arc<MessageChannel<F>>>>,
    /// Exchanges boundary data between neighbouring chunks.
    /// NOTE(review): not exercised by the code visible in this file.
    boundary_exchanger: Arc<BoundaryExchanger<F>>,
    /// Solver-wide configuration supplied at construction.
    config: DistributedConfig<F>,
    /// Monotonic counter used to mint unique job ids.
    next_job_id: AtomicU64,
    /// Set once `shutdown()` has been requested.
    shutdown: AtomicBool,
    /// Bookkeeping for jobs currently being solved, keyed by job id.
    active_jobs: RwLock<HashMap<JobId, JobState<F>>>,
    /// Aggregated runtime metrics.
    metrics: Mutex<DistributedMetrics>,
}
53
/// Mutable per-job bookkeeping kept while a distributed solve is running.
struct JobState<F: IntegrateFloat> {
    /// Identifier of the job this state belongs to.
    job_id: JobId,
    /// Full integration interval (t_start, t_end) requested by the caller.
    t_span: (F, F),
    /// Initial condition y(t_start) supplied by the caller.
    initial_state: Array1<F>,
    /// Number of chunks the time span was split into.
    total_chunks: usize,
    /// Results of chunks that have finished processing.
    completed_chunks: Vec<ChunkResult<F>>,
    /// Chunks created but not yet assigned to a node.
    pending_chunks: Vec<ChunkId>,
    /// Chunks currently assigned, mapped to the node working on them.
    in_progress_chunks: HashMap<ChunkId, NodeId>,
    /// Chunk ids in time order; used to chain initial states and to
    /// reassemble results in the correct sequence.
    chunk_order: Vec<ChunkId>,
    /// When the job started (drives total-time reporting).
    start_time: Instant,
    /// When the last checkpoint was taken, if any.
    last_checkpoint: Option<Instant>,
    /// Chunks completed since the last checkpoint (checkpoint trigger).
    chunks_since_checkpoint: usize,
}
79
impl<F: IntegrateFloat> DistributedODESolver<F> {
    /// Construct a solver together with all supporting subsystems
    /// (node manager, load balancer, checkpointing, fault tolerance,
    /// boundary exchange) from `config`.
    ///
    /// # Errors
    /// Fails only if the `CheckpointManager` cannot be created.
    pub fn new(config: DistributedConfig<F>) -> DistributedResult<Self> {
        let node_manager = Arc::new(NodeManager::new(config.heartbeat_interval));

        let load_balancer = Arc::new(LoadBalancer::new(
            config.load_balancing,
            LoadBalancerConfig::default(),
        ));

        // NOTE(review): the checkpoint directory is hard-coded; consider
        // exposing it through DistributedConfig.
        let checkpoint_path = PathBuf::from("/tmp/scirs_checkpoints");
        let checkpoint_config = CheckpointConfig {
            persist_to_disk: config.checkpointing_enabled,
            interval_chunks: config.checkpoint_interval,
            ..Default::default()
        };

        let checkpoint_manager =
            Arc::new(CheckpointManager::new(checkpoint_path, checkpoint_config)?);

        let fault_coordinator = Arc::new(FaultToleranceCoordinator::new(
            Arc::clone(&checkpoint_manager),
            config.fault_tolerance,
        ));

        let boundary_exchanger = Arc::new(BoundaryExchanger::new(config.communication_timeout));

        Ok(Self {
            node_manager,
            load_balancer,
            checkpoint_manager,
            fault_coordinator,
            channels: RwLock::new(HashMap::new()),
            boundary_exchanger,
            config,
            next_job_id: AtomicU64::new(1),
            shutdown: AtomicBool::new(false),
            active_jobs: RwLock::new(HashMap::new()),
            metrics: Mutex::new(DistributedMetrics::default()),
        })
    }

    /// Register a compute node with the node manager and load balancer and
    /// open a message channel for it.
    ///
    /// # Errors
    /// Propagates registration failures from either subsystem.
    pub fn register_node(&self, node: NodeInfo) -> DistributedResult<()> {
        let node_id = node.id;

        // NOTE(review): the node manager registers by address/capabilities
        // while the load balancer is handed `node.id` directly — confirm the
        // two agree on the id actually assigned.
        self.node_manager
            .register_node(node.address, node.capabilities.clone())?;

        self.load_balancer.register_node(node_id)?;

        let channel = Arc::new(MessageChannel::new(self.config.communication_timeout));
        // A poisoned lock is silently skipped; the channel is then simply
        // not recorded for this node.
        if let Ok(mut channels) = self.channels.write() {
            channels.insert(node_id, channel);
        }

        Ok(())
    }

    /// Remove a node from all subsystems and drop its message channel.
    ///
    /// # Errors
    /// Propagates deregistration failures from either subsystem.
    pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
        self.node_manager.deregister_node(node_id)?;
        self.load_balancer.deregister_node(node_id)?;

        if let Ok(mut channels) = self.channels.write() {
            channels.remove(&node_id);
        }

        Ok(())
    }

    /// Solve `dy/dt = f(t, y)` over `t_span` starting from `y0`.
    ///
    /// The interval is split into chunks (one group per available node),
    /// chunk work is scheduled and processed, and the per-chunk solutions
    /// are reassembled into a single `DistributedODEResult`.
    ///
    /// NOTE(review): `options` is accepted but never read by the processing
    /// path below — confirm whether it should configure the integrator.
    ///
    /// # Errors
    /// Returns an error when no compute nodes are available, when job state
    /// cannot be accessed, or when the solve times out.
    pub fn solve<Func>(
        &self,
        f: Func,
        t_span: (F, F),
        y0: Array1<F>,
        options: Option<ODEOptions<F>>,
    ) -> IntegrateResult<DistributedODEResult<F>>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        let start_time = Instant::now();

        let available_nodes = self.node_manager.get_available_nodes();
        if available_nodes.is_empty() {
            return Err(IntegrateError::ComputationError(
                "No compute nodes available".to_string(),
            ));
        }

        // Mint a fresh job id from the atomic counter.
        let job_id = JobId::new(self.next_job_id.fetch_add(1, Ordering::SeqCst));

        // At least one chunk, even when configuration yields zero.
        let num_chunks = (available_nodes.len() * self.config.chunks_per_node).max(1);

        let distributor = ChunkDistributor::new(job_id);
        let chunks = distributor.create_chunks(t_span, y0.clone(), num_chunks);

        // Remember the time-ordering of chunks; result assembly and
        // initial-state chaining both depend on it.
        let chunk_order: Vec<ChunkId> = chunks.iter().map(|c| c.id).collect();
        let pending_chunks = chunk_order.clone();

        let job_state = JobState {
            job_id,
            t_span,
            initial_state: y0.clone(),
            total_chunks: num_chunks,
            completed_chunks: Vec::new(),
            pending_chunks,
            in_progress_chunks: HashMap::new(),
            chunk_order,
            start_time,
            last_checkpoint: None,
            chunks_since_checkpoint: 0,
        };

        if let Ok(mut jobs) = self.active_jobs.write() {
            jobs.insert(job_id, job_state);
        }

        // Assign chunks to nodes, then drive processing to completion.
        self.distribute_chunks(job_id, chunks, &available_nodes, &f)?;

        let result = self.wait_for_completion(job_id, &f)?;

        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.total_processing_time += start_time.elapsed();
        }

        // Drop the job's bookkeeping now that the result is assembled.
        if let Ok(mut jobs) = self.active_jobs.write() {
            jobs.remove(&job_id);
        }

        Ok(result)
    }

    /// Assign every chunk to a node via the load balancer and move it from
    /// `pending_chunks` to `in_progress_chunks`.
    ///
    /// NOTE(review): despite the name, nothing is sent to remote nodes here;
    /// chunks are reconstructed and integrated locally by
    /// `process_pending_chunks`. The `f` parameter is currently unused.
    ///
    /// # Errors
    /// Fails if the load balancer cannot assign a chunk.
    fn distribute_chunks<Func>(
        &self,
        job_id: JobId,
        chunks: Vec<WorkChunk<F>>,
        nodes: &[NodeInfo],
        f: &Func,
    ) -> IntegrateResult<()>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        for chunk in chunks {
            let node_id = self
                .load_balancer
                .assign_chunk(&chunk, nodes)
                .map_err(|e| IntegrateError::ComputationError(e.to_string()))?;

            if let Ok(mut jobs) = self.active_jobs.write() {
                if let Some(job) = jobs.get_mut(&job_id) {
                    job.pending_chunks.retain(|id| *id != chunk.id);
                    job.in_progress_chunks.insert(chunk.id, node_id);
                }
            }

        }

        Ok(())
    }

    /// Poll job state until all chunks finish, driving chunk processing
    /// in-line, then assemble the final result.
    ///
    /// # Errors
    /// Times out after a fixed deadline, or fails if job state cannot be
    /// read or the job disappears.
    fn wait_for_completion<Func>(
        &self,
        job_id: JobId,
        f: &Func,
    ) -> IntegrateResult<DistributedODEResult<F>>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        // NOTE(review): hard-coded 1-hour ceiling on the whole solve.
        let timeout = Duration::from_secs(3600);
        let deadline = Instant::now() + timeout;

        loop {
            if Instant::now() > deadline {
                return Err(IntegrateError::ConvergenceError(
                    "Distributed solve timeout".to_string(),
                ));
            }

            // Read completion status under a short-lived read lock so the
            // lock is released before any processing work below.
            let (is_complete, needs_processing) = {
                let jobs = self.active_jobs.read().map_err(|_| {
                    IntegrateError::ComputationError("Failed to read job state".to_string())
                })?;

                if let Some(job) = jobs.get(&job_id) {
                    let complete =
                        job.pending_chunks.is_empty() && job.in_progress_chunks.is_empty();
                    let needs = !job.in_progress_chunks.is_empty();
                    (complete, needs)
                } else {
                    return Err(IntegrateError::ComputationError(
                        "Job not found".to_string(),
                    ));
                }
            };

            if is_complete {
                break;
            }

            if needs_processing {
                self.process_pending_chunks(job_id, f)?;
            }

            // Brief back-off between polls.
            std::thread::sleep(Duration::from_millis(10));
        }

        self.assemble_result(job_id)
    }

    /// Process every in-progress chunk, in time order, on the current
    /// thread: reconstruct each chunk's time interval and initial state,
    /// integrate it, record the result, and checkpoint when due.
    ///
    /// # Errors
    /// Fails if job state cannot be read or a chunk's context is missing.
    fn process_pending_chunks<Func>(&self, job_id: JobId, f: &Func) -> IntegrateResult<()>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        // Snapshot the in-progress chunks sorted by their position in
        // `chunk_order`, so each chunk is processed after its predecessor
        // and can consume the predecessor's final state.
        let ordered_chunks: Vec<(ChunkId, NodeId, usize)> = {
            let jobs = self.active_jobs.read().map_err(|_| {
                IntegrateError::ComputationError("Failed to read job state".to_string())
            })?;

            if let Some(job) = jobs.get(&job_id) {
                let mut items: Vec<(ChunkId, NodeId, usize)> = job
                    .in_progress_chunks
                    .iter()
                    .map(|(chunk_id, node_id)| {
                        let idx = job
                            .chunk_order
                            .iter()
                            .position(|id| id == chunk_id)
                            .unwrap_or(0);
                        (*chunk_id, *node_id, idx)
                    })
                    .collect();
                items.sort_by_key(|&(_, _, idx)| idx);
                items
            } else {
                Vec::new()
            }
        };

        for (chunk_id, node_id, idx) in ordered_chunks {
            // Rebuild the chunk from job metadata: a uniform partition of
            // the job's time span, chained to the previous chunk's result.
            let chunk = {
                let jobs = self.active_jobs.read().map_err(|_| {
                    IntegrateError::ComputationError("Failed to read job state".to_string())
                })?;
                let job = jobs
                    .get(&job_id)
                    .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;

                let (t_start, t_end) = job.t_span;
                let dt = (t_end - t_start) / F::from(job.total_chunks).unwrap_or(F::one());

                let chunk_t_start = t_start + dt * F::from(idx).unwrap_or(F::zero());
                // The final chunk ends exactly at t_end to avoid rounding
                // drift from repeated dt accumulation.
                let chunk_t_end = if idx == job.total_chunks - 1 {
                    t_end
                } else {
                    t_start + dt * F::from(idx + 1).unwrap_or(F::one())
                };

                let initial_state = if idx == 0 {
                    job.initial_state.clone()
                } else {
                    let prev_chunk_id = job.chunk_order.get(idx - 1).ok_or_else(|| {
                        IntegrateError::ComputationError(
                            "Previous chunk not found in order".to_string(),
                        )
                    })?;
                    // NOTE(review): if the predecessor hasn't completed yet
                    // this silently falls back to the job's initial state,
                    // which would be incorrect — confirm the in-order
                    // processing above makes that unreachable.
                    job.completed_chunks
                        .iter()
                        .find(|r| r.chunk_id == *prev_chunk_id)
                        .map(|r| r.final_state.clone())
                        .unwrap_or_else(|| job.initial_state.clone())
                };

                WorkChunk::new(
                    chunk_id,
                    job_id,
                    (chunk_t_start, chunk_t_end),
                    initial_state,
                )
            };

            let result = self.process_single_chunk(&chunk, node_id, f)?;

            // Record completion and checkpoint under the write lock.
            if let Ok(mut jobs) = self.active_jobs.write() {
                if let Some(job) = jobs.get_mut(&job_id) {
                    job.in_progress_chunks.remove(&chunk_id);
                    job.completed_chunks.push(result);
                    job.chunks_since_checkpoint += 1;

                    if self.config.checkpointing_enabled
                        && self
                            .checkpoint_manager
                            .should_checkpoint(job.chunks_since_checkpoint)
                    {
                        // NOTE(review): iteration/current_time/error_estimate
                        // are placeholders (zeros) in this snapshot.
                        let global_state = CheckpointGlobalState {
                            iteration: 0,
                            chunks_completed: job.completed_chunks.len(),
                            chunks_remaining: job.pending_chunks.len()
                                + job.in_progress_chunks.len(),
                            current_time: F::zero(),
                            error_estimate: F::zero(),
                        };

                        // Checkpoint failure is deliberately non-fatal.
                        let _ = self.checkpoint_manager.create_checkpoint(
                            job_id,
                            job.completed_chunks.clone(),
                            job.in_progress_chunks.keys().cloned().collect(),
                            global_state,
                        );

                        job.chunks_since_checkpoint = 0;
                        job.last_checkpoint = Some(Instant::now());
                    }
                }
            }

            // NOTE(review): a fixed 10 ms placeholder is reported instead of
            // the chunk's actual processing time.
            let processing_time = Duration::from_millis(10);
            self.load_balancer.report_completion(
                node_id,
                chunk.estimated_cost,
                processing_time,
                true,
            );
        }

        Ok(())
    }

    /// Integrate one chunk with a fixed-step classical 4th-order
    /// Runge–Kutta scheme (100 steps over the chunk's interval) and return
    /// the trajectory plus final state.
    ///
    /// # Errors
    /// Currently infallible in practice; the `Result` matches the callers.
    fn process_single_chunk<Func>(
        &self,
        chunk: &WorkChunk<F>,
        node_id: NodeId,
        f: &Func,
    ) -> IntegrateResult<ChunkResult<F>>
    where
        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
    {
        let start_time = Instant::now();

        let (t_start, t_end) = chunk.time_interval;
        let mut t = t_start;
        let mut y = chunk.initial_state.clone();

        // NOTE(review): fixed step count; `ODEOptions` tolerances are not
        // consulted here.
        let n_steps = 100;
        let h = (t_end - t_start) / F::from(n_steps).unwrap_or(F::one());

        let mut time_points = vec![t_start];
        let mut states = vec![y.clone()];

        for _ in 0..n_steps {
            // Classical RK4 stages.
            let k1 = f(t, y.view());
            let k2 = f(
                t + h / F::from(2.0).unwrap_or(F::one()),
                (&y + &(&k1 * h / F::from(2.0).unwrap_or(F::one()))).view(),
            );
            let k3 = f(
                t + h / F::from(2.0).unwrap_or(F::one()),
                (&y + &(&k2 * h / F::from(2.0).unwrap_or(F::one()))).view(),
            );
            let k4 = f(t + h, (&y + &(&k3 * h)).view());

            // Weighted combination: y += h/6 * (k1 + 2*k2 + 2*k3 + k4).
            y = &y
                + &((&k1
                    + &(&k2 * F::from(2.0).unwrap_or(F::one()))
                    + &(&k3 * F::from(2.0).unwrap_or(F::one()))
                    + &k4)
                    * h
                    / F::from(6.0).unwrap_or(F::one()));
            t += h;

            time_points.push(t);
            states.push(y.clone());
        }

        let final_state = y.clone();
        let final_derivative = Some(f(t, y.view()));

        Ok(ChunkResult {
            chunk_id: chunk.id,
            node_id,
            time_points,
            states,
            final_state,
            final_derivative,
            // NOTE(review): fixed nominal error estimate, not computed.
            error_estimate: F::from(1e-6).unwrap_or(F::epsilon()),
            processing_time: start_time.elapsed(),
            memory_used: 0,
            status: ChunkResultStatus::Success,
        })
    }

    /// Stitch the completed chunks, in `chunk_order`, into one contiguous
    /// trajectory and package it with run statistics.
    ///
    /// # Errors
    /// Fails if job state cannot be read or the job is missing.
    fn assemble_result(&self, job_id: JobId) -> IntegrateResult<DistributedODEResult<F>> {
        let jobs = self.active_jobs.read().map_err(|_| {
            IntegrateError::ComputationError("Failed to read job state".to_string())
        })?;

        let job = jobs
            .get(&job_id)
            .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;

        // Order results by their chunk's position in the time ordering;
        // unknown chunks sort last.
        let mut sorted_results: Vec<_> = job.completed_chunks.clone();
        sorted_results.sort_by_key(|r| {
            job.chunk_order
                .iter()
                .position(|id| *id == r.chunk_id)
                .unwrap_or(usize::MAX)
        });

        let mut t_all = Vec::new();
        let mut y_all = Vec::new();

        for (i, result) in sorted_results.iter().enumerate() {
            // Skip each chunk's first sample (except the very first chunk):
            // it duplicates the previous chunk's final sample.
            let skip_first = if i > 0 { 1 } else { 0 };
            t_all.extend(result.time_points.iter().skip(skip_first).cloned());
            y_all.extend(result.states.iter().skip(skip_first).cloned());
        }

        let total_time = job.start_time.elapsed();

        let metrics = self.metrics.lock().map(|m| m.clone()).unwrap_or_default();

        Ok(DistributedODEResult {
            t: t_all,
            y: y_all,
            job_id,
            chunks_processed: job.completed_chunks.len(),
            nodes_used: job
                .completed_chunks
                .iter()
                .map(|r| r.node_id)
                .collect::<std::collections::HashSet<_>>()
                .len(),
            total_time,
            metrics,
        })
    }

    /// Request shutdown: raise the flag and stop node health monitoring.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::Relaxed);
        self.node_manager.stop_health_monitoring();
    }

    /// Snapshot of the current metrics (defaults if the lock is poisoned).
    pub fn get_metrics(&self) -> DistributedMetrics {
        self.metrics.lock().map(|m| m.clone()).unwrap_or_default()
    }
}
569
/// Final assembled solution of a distributed ODE solve.
#[derive(Debug, Clone)]
pub struct DistributedODEResult<F: IntegrateFloat> {
    /// Time points of the assembled trajectory, in ascending order.
    pub t: Vec<F>,
    /// State vectors corresponding one-to-one with `t`.
    pub y: Vec<Array1<F>>,
    /// Identifier of the job that produced this result.
    pub job_id: JobId,
    /// Number of work chunks that contributed to the result.
    pub chunks_processed: usize,
    /// Number of distinct nodes that produced chunk results.
    pub nodes_used: usize,
    /// Wall-clock duration of the whole solve.
    pub total_time: Duration,
    /// Snapshot of solver metrics taken at assembly time.
    pub metrics: DistributedMetrics,
}
588
589impl<F: IntegrateFloat> DistributedODEResult<F> {
590 pub fn final_state(&self) -> Option<&Array1<F>> {
592 self.y.last()
593 }
594
595 pub fn state_at(&self, index: usize) -> Option<&Array1<F>> {
597 self.y.get(index)
598 }
599
600 pub fn len(&self) -> usize {
602 self.t.len()
603 }
604
605 pub fn is_empty(&self) -> bool {
607 self.t.is_empty()
608 }
609
610 pub fn interpolate(&self, t_target: F) -> Option<Array1<F>> {
612 if self.t.is_empty() {
613 return None;
614 }
615
616 let mut left_idx = 0;
618 for (i, &t) in self.t.iter().enumerate() {
619 if t <= t_target {
620 left_idx = i;
621 } else {
622 break;
623 }
624 }
625
626 let right_idx = (left_idx + 1).min(self.t.len() - 1);
627
628 if left_idx == right_idx {
629 return self.y.get(left_idx).cloned();
630 }
631
632 let t_left = self.t[left_idx];
634 let t_right = self.t[right_idx];
635 let dt = t_right - t_left;
636
637 if dt.abs() < F::epsilon() {
638 return self.y.get(left_idx).cloned();
639 }
640
641 let alpha = (t_target - t_left) / dt;
642 let y_left = &self.y[left_idx];
643 let y_right = &self.y[right_idx];
644
645 Some(y_left * (F::one() - alpha) + y_right * alpha)
646 }
647}
648
/// Fluent builder for [`DistributedODESolver`].
pub struct DistributedODESolverBuilder<F: IntegrateFloat> {
    /// Configuration accumulated by the builder methods.
    config: DistributedConfig<F>,
}
653
654impl<F: IntegrateFloat> DistributedODESolverBuilder<F> {
655 pub fn new() -> Self {
657 Self {
658 config: DistributedConfig::default(),
659 }
660 }
661
662 pub fn tolerance(mut self, tol: F) -> Self {
664 self.config.tolerance = tol;
665 self
666 }
667
668 pub fn chunks_per_node(mut self, n: usize) -> Self {
670 self.config.chunks_per_node = n;
671 self
672 }
673
674 pub fn with_checkpointing(mut self, interval: usize) -> Self {
676 self.config.checkpointing_enabled = true;
677 self.config.checkpoint_interval = interval;
678 self
679 }
680
681 pub fn fault_tolerance(mut self, mode: FaultToleranceMode) -> Self {
683 self.config.fault_tolerance = mode;
684 self
685 }
686
687 pub fn timeout(mut self, timeout: Duration) -> Self {
689 self.config.communication_timeout = timeout;
690 self
691 }
692
693 pub fn build(self) -> DistributedResult<DistributedODESolver<F>> {
695 DistributedODESolver::new(self.config)
696 }
697}
698
699impl<F: IntegrateFloat> Default for DistributedODESolverBuilder<F> {
700 fn default() -> Self {
701 Self::new()
702 }
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::types::NodeCapabilities;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Build an Available test node listening on 127.0.0.1:(8080 + id).
    fn create_test_node(id: u64) -> NodeInfo {
        let address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080 + id as u16);
        let mut node = NodeInfo::new(NodeId::new(id), address);
        node.capabilities = NodeCapabilities::default();
        node.status = NodeStatus::Available;
        node
    }

    #[test]
    fn test_distributed_solver_creation() {
        // Default config must yield a working solver.
        assert!(DistributedODESolver::new(DistributedConfig::<f64>::default()).is_ok());
    }

    #[test]
    fn test_distributed_solver_node_registration() {
        let solver = DistributedODESolver::new(DistributedConfig::<f64>::default())
            .expect("Failed to create solver");

        assert!(solver.register_node(create_test_node(1)).is_ok());
    }

    #[test]
    fn test_distributed_solve_simple_ode() {
        let solver = DistributedODESolver::new(DistributedConfig::<f64>::default())
            .expect("Failed to create solver");

        for id in 0..2 {
            solver
                .register_node(create_test_node(id))
                .expect("Failed to register node");
        }

        // dy/dt = -y with y(0) = 1 has the exact solution e^{-t}.
        let rhs = |_t: f64, y: ArrayView1<f64>| array![-y[0]];
        let outcome = solver.solve(rhs, (0.0, 1.0), array![1.0], None);
        assert!(outcome.is_ok());

        let solution = outcome.expect("Solve failed");
        assert!(!solution.t.is_empty());
        assert!(!solution.y.is_empty());

        let expected = (-1.0_f64).exp();
        let actual = solution.final_state().expect("No final state")[0];
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn test_distributed_result_interpolation() {
        let result = DistributedODEResult::<f64> {
            t: vec![0.0, 0.5, 1.0],
            y: vec![array![1.0], array![0.6], array![0.4]],
            job_id: JobId::new(1),
            chunks_processed: 1,
            nodes_used: 1,
            total_time: Duration::from_secs(1),
            metrics: DistributedMetrics::default(),
        };

        // Halfway between t=0.0 (y=1.0) and t=0.5 (y=0.6) -> y ~= 0.8.
        let value = result.interpolate(0.25).expect("Interpolation failed");
        assert!((value[0] - 0.8_f64).abs() < 0.01_f64);
    }

    #[test]
    fn test_solver_builder() {
        let built = DistributedODESolverBuilder::<f64>::new()
            .tolerance(1e-8)
            .chunks_per_node(8)
            .with_checkpointing(5)
            .fault_tolerance(FaultToleranceMode::Standard)
            .timeout(Duration::from_secs(60))
            .build();

        assert!(built.is_ok());
    }
}