1use crate::common::IntegrateFloat;
7use crate::distributed::types::{
8 ChunkId, ChunkResult, ChunkResultStatus, DistributedError, DistributedResult,
9 FaultToleranceMode, JobId, NodeId,
10};
11use scirs2_core::ndarray::Array1;
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::fs::{self, File};
14use std::io::{Read, Write};
15use std::path::{Path, PathBuf};
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, Mutex, RwLock};
18use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
19
/// Stores and persists integration-job checkpoints so work can be resumed
/// after a failure.
///
/// Checkpoints are kept in memory per job (newest last, capped by
/// `CheckpointConfig::max_checkpoints_per_job`) and optionally mirrored to
/// files under `storage_path`.
pub struct CheckpointManager<F: IntegrateFloat> {
    /// Directory that on-disk checkpoint files are written to.
    storage_path: PathBuf,
    /// Per-job checkpoint history; the last element is the most recent.
    checkpoints: RwLock<HashMap<JobId, Vec<Checkpoint<F>>>>,
    /// Monotonically increasing checkpoint-id source; starts at 1.
    next_checkpoint_id: AtomicU64,
    /// Retention, interval, and persistence settings.
    config: CheckpointConfig,
    /// Rolling log (capped at 100 entries) of checkpoint creation instants,
    /// consumed by `should_checkpoint` for the time-based trigger.
    checkpoint_times: Mutex<VecDeque<Instant>>,
}
33
/// Tuning parameters for `CheckpointManager`.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Retention cap per job; oldest checkpoints beyond this are evicted
    /// (and deleted from disk when persistence is enabled).
    pub max_checkpoints_per_job: usize,
    /// Chunk-count trigger: checkpoint after this many chunks complete.
    pub interval_chunks: usize,
    /// Time trigger: checkpoint when this long has passed since the last one.
    pub interval_duration: Duration,
    /// Whether checkpoints are also written to disk.
    pub persist_to_disk: bool,
    /// Compression flag. NOTE(review): not referenced by any code in this
    /// file — confirm whether it is honored elsewhere or is a placeholder.
    pub compress: bool,
    /// Read each checkpoint file back after writing and compare byte-for-byte.
    pub verify_writes: bool,
}
50
impl Default for CheckpointConfig {
    /// Conservative defaults: keep 5 checkpoints per job, checkpoint every
    /// 10 chunks or 60 s, persist to disk with write verification, and no
    /// compression.
    fn default() -> Self {
        Self {
            max_checkpoints_per_job: 5,
            interval_chunks: 10,
            interval_duration: Duration::from_secs(60),
            persist_to_disk: true,
            compress: false,
            verify_writes: true,
        }
    }
}
63
/// A point-in-time snapshot of a distributed integration job.
#[derive(Debug, Clone)]
pub struct Checkpoint<F: IntegrateFloat> {
    /// Unique id assigned by `CheckpointManager`.
    pub id: u64,
    /// Job this snapshot belongs to.
    pub job_id: JobId,
    /// Wall-clock creation time.
    pub timestamp: SystemTime,
    /// Final states of chunks that had completed successfully at snapshot time.
    pub completed_chunks: Vec<ChunkCheckpoint<F>>,
    /// Chunks that were still running; these need re-execution after restore.
    pub in_progress_chunks: Vec<ChunkId>,
    /// Job-wide progress counters and integration state.
    pub global_state: CheckpointGlobalState<F>,
    /// Integrity hash of the payload, re-verified by
    /// `CheckpointManager::restore`.
    pub validation_hash: u64,
}
82
/// Saved terminal state of one successfully completed chunk.
#[derive(Debug, Clone)]
pub struct ChunkCheckpoint<F: IntegrateFloat> {
    /// Chunk this state belongs to.
    pub chunk_id: ChunkId,
    /// Integration time at the end of the chunk (zero when the result
    /// carried no time points).
    pub final_time: F,
    /// Solution vector at `final_time`.
    pub final_state: Array1<F>,
    /// Derivative at `final_time`, when the solver provided one.
    pub final_derivative: Option<Array1<F>>,
    /// Node that processed the chunk.
    pub node_id: NodeId,
    /// Wall-clock time the chunk took to process.
    pub processing_time: Duration,
}
99
/// Job-level progress captured alongside the per-chunk states.
#[derive(Debug, Clone, Default)]
pub struct CheckpointGlobalState<F: IntegrateFloat> {
    /// Iteration count at snapshot time.
    pub iteration: usize,
    /// Number of chunks finished so far.
    pub chunks_completed: usize,
    /// Number of chunks still outstanding.
    pub chunks_remaining: usize,
    /// Current integration time. NOTE(review): not covered by
    /// `calculate_hash` — confirm that is intended.
    pub current_time: F,
    /// Accumulated error estimate. Also excluded from the validation hash.
    pub error_estimate: F,
}
114
impl<F: IntegrateFloat> CheckpointManager<F> {
    /// Construct a manager rooted at `storage_path`.
    ///
    /// When `config.persist_to_disk` is set, the storage directory is
    /// created eagerly so later writes cannot fail on a missing path.
    pub fn new(storage_path: PathBuf, config: CheckpointConfig) -> DistributedResult<Self> {
        if config.persist_to_disk {
            if let Err(e) = fs::create_dir_all(&storage_path) {
                return Err(DistributedError::CheckpointError(format!(
                    "Failed to create checkpoint directory: {}",
                    e
                )));
            }
        }

        Ok(Self {
            storage_path,
            checkpoints: RwLock::new(HashMap::new()),
            next_checkpoint_id: AtomicU64::new(1),
            config,
            checkpoint_times: Mutex::new(VecDeque::new()),
        })
    }
136
137 pub fn create_checkpoint(
139 &self,
140 job_id: JobId,
141 completed_chunks: Vec<ChunkResult<F>>,
142 in_progress_chunks: Vec<ChunkId>,
143 global_state: CheckpointGlobalState<F>,
144 ) -> DistributedResult<u64> {
145 let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
146
147 let chunk_checkpoints: Vec<ChunkCheckpoint<F>> = completed_chunks
149 .into_iter()
150 .filter(|r| r.status == ChunkResultStatus::Success)
151 .map(|r| ChunkCheckpoint {
152 chunk_id: r.chunk_id,
153 final_time: r.time_points.last().copied().unwrap_or(F::zero()),
154 final_state: r.final_state.clone(),
155 final_derivative: r.final_derivative.clone(),
156 node_id: r.node_id,
157 processing_time: r.processing_time,
158 })
159 .collect();
160
161 let validation_hash = self.calculate_hash(&chunk_checkpoints, &global_state);
163
164 let checkpoint = Checkpoint {
165 id: checkpoint_id,
166 job_id,
167 timestamp: SystemTime::now(),
168 completed_chunks: chunk_checkpoints,
169 in_progress_chunks,
170 global_state,
171 validation_hash,
172 };
173
174 {
176 let mut checkpoints = self.checkpoints.write().map_err(|_| {
177 DistributedError::CheckpointError("Failed to acquire checkpoint lock".to_string())
178 })?;
179
180 let job_checkpoints = checkpoints.entry(job_id).or_insert_with(Vec::new);
181 job_checkpoints.push(checkpoint.clone());
182
183 while job_checkpoints.len() > self.config.max_checkpoints_per_job {
185 let removed = job_checkpoints.remove(0);
186 if self.config.persist_to_disk {
187 let _ = self.delete_from_disk(job_id, removed.id);
188 }
189 }
190 }
191
192 if self.config.persist_to_disk {
194 self.save_to_disk(&checkpoint)?;
195 }
196
197 if let Ok(mut times) = self.checkpoint_times.lock() {
199 times.push_back(Instant::now());
200 while times.len() > 100 {
201 times.pop_front();
202 }
203 }
204
205 Ok(checkpoint_id)
206 }
207
208 pub fn get_latest_checkpoint(&self, job_id: JobId) -> Option<Checkpoint<F>> {
210 match self.checkpoints.read() {
211 Ok(checkpoints) => checkpoints.get(&job_id).and_then(|cps| cps.last().cloned()),
212 Err(_) => None,
213 }
214 }
215
216 pub fn get_checkpoint(&self, job_id: JobId, checkpoint_id: u64) -> Option<Checkpoint<F>> {
218 match self.checkpoints.read() {
219 Ok(checkpoints) => checkpoints
220 .get(&job_id)
221 .and_then(|cps| cps.iter().find(|cp| cp.id == checkpoint_id).cloned()),
222 Err(_) => None,
223 }
224 }
225
226 pub fn restore(
228 &self,
229 job_id: JobId,
230 checkpoint_id: Option<u64>,
231 ) -> DistributedResult<Checkpoint<F>> {
232 let checkpoint = if let Some(id) = checkpoint_id {
233 self.get_checkpoint(job_id, id)
234 } else {
235 self.get_latest_checkpoint(job_id)
236 };
237
238 let checkpoint = checkpoint.ok_or_else(|| {
239 DistributedError::CheckpointError(format!("No checkpoint found for job {:?}", job_id))
240 })?;
241
242 let expected_hash =
244 self.calculate_hash(&checkpoint.completed_chunks, &checkpoint.global_state);
245 if expected_hash != checkpoint.validation_hash {
246 return Err(DistributedError::CheckpointError(
247 "Checkpoint validation failed".to_string(),
248 ));
249 }
250
251 Ok(checkpoint)
252 }
253
254 pub fn cleanup_job(&self, job_id: JobId) -> DistributedResult<()> {
256 if let Ok(mut checkpoints) = self.checkpoints.write() {
257 if let Some(job_cps) = checkpoints.remove(&job_id) {
258 if self.config.persist_to_disk {
259 for cp in job_cps {
260 let _ = self.delete_from_disk(job_id, cp.id);
261 }
262 }
263 }
264 }
265 Ok(())
266 }
267
    /// Compute the validation hash over a checkpoint payload.
    ///
    /// The same function runs at creation (`create_checkpoint`) and at
    /// restore time, so every field it covers is protected against
    /// in-memory corruption between those points.
    ///
    /// NOTE(review): the hash covers chunk ids, node ids, the final-state
    /// values, and the three integer counters of the global state. It does
    /// NOT cover `final_time`, `final_derivative`, `current_time`, or
    /// `error_estimate` — confirm that exclusion is intentional.
    fn calculate_hash(
        &self,
        chunks: &[ChunkCheckpoint<F>],
        global_state: &CheckpointGlobalState<F>,
    ) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();

        for chunk in chunks {
            chunk.chunk_id.0.hash(&mut hasher);
            chunk.node_id.0.hash(&mut hasher);

            // Hash the raw f64 bit pattern: floats do not implement `Hash`,
            // and bit-level hashing is deterministic even for NaN/-0.0.
            for val in chunk.final_state.iter() {
                let bits = val.to_f64().unwrap_or(0.0).to_bits();
                bits.hash(&mut hasher);
            }
        }

        global_state.iteration.hash(&mut hasher);
        global_state.chunks_completed.hash(&mut hasher);
        global_state.chunks_remaining.hash(&mut hasher);

        hasher.finish()
    }
298
299 fn save_to_disk(&self, checkpoint: &Checkpoint<F>) -> DistributedResult<()> {
301 let filename = format!(
302 "checkpoint_{}_{}.bin",
303 checkpoint.job_id.value(),
304 checkpoint.id
305 );
306 let path = self.storage_path.join(&filename);
307
308 let data = self.serialize_checkpoint(checkpoint)?;
310
311 let mut file = File::create(&path).map_err(|e| {
312 DistributedError::CheckpointError(format!("Failed to create checkpoint file: {}", e))
313 })?;
314
315 file.write_all(&data).map_err(|e| {
316 DistributedError::CheckpointError(format!("Failed to write checkpoint: {}", e))
317 })?;
318
319 if self.config.verify_writes {
321 let mut verify_file = File::open(&path).map_err(|e| {
322 DistributedError::CheckpointError(format!(
323 "Failed to verify checkpoint file: {}",
324 e
325 ))
326 })?;
327
328 let mut verify_data = Vec::new();
329 verify_file.read_to_end(&mut verify_data).map_err(|e| {
330 DistributedError::CheckpointError(format!("Failed to read back checkpoint: {}", e))
331 })?;
332
333 if verify_data != data {
334 return Err(DistributedError::CheckpointError(
335 "Checkpoint verification failed".to_string(),
336 ));
337 }
338 }
339
340 Ok(())
341 }
342
343 fn delete_from_disk(&self, job_id: JobId, checkpoint_id: u64) -> DistributedResult<()> {
345 let filename = format!("checkpoint_{}_{}.bin", job_id.value(), checkpoint_id);
346 let path = self.storage_path.join(&filename);
347
348 if path.exists() {
349 fs::remove_file(&path).map_err(|e| {
350 DistributedError::CheckpointError(format!(
351 "Failed to delete checkpoint file: {}",
352 e
353 ))
354 })?;
355 }
356
357 Ok(())
358 }
359
    /// Serialize a checkpoint into a flat little-endian byte buffer.
    ///
    /// Layout (all integers little-endian):
    /// `id:u64 | job_id | validation_hash:u64 | timestamp_secs:u64 |
    /// iteration | chunks_completed | chunks_remaining | chunk_count:u64 |`
    /// then per chunk: `chunk_id | node_id | final_time:f64 | state_len:u64 |
    /// state values as f64`.
    ///
    /// NOTE(review): `usize::to_le_bytes` makes the counter fields
    /// platform-width (8 bytes on 64-bit targets), so files are not
    /// portable across architectures — confirm whether cross-platform
    /// restore matters before relying on this format. `final_derivative`
    /// and `in_progress_chunks` are not serialized.
    fn serialize_checkpoint(&self, checkpoint: &Checkpoint<F>) -> DistributedResult<Vec<u8>> {
        let mut data = Vec::new();

        // Header: id, job, integrity hash.
        data.extend_from_slice(&checkpoint.id.to_le_bytes());
        data.extend_from_slice(&checkpoint.job_id.value().to_le_bytes());
        data.extend_from_slice(&checkpoint.validation_hash.to_le_bytes());

        // Creation time as whole seconds since the Unix epoch; a pre-epoch
        // clock degrades to 0 rather than erroring.
        let timestamp_secs = checkpoint
            .timestamp
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_secs();
        data.extend_from_slice(&timestamp_secs.to_le_bytes());

        // Global progress counters.
        data.extend_from_slice(&checkpoint.global_state.iteration.to_le_bytes());
        data.extend_from_slice(&checkpoint.global_state.chunks_completed.to_le_bytes());
        data.extend_from_slice(&checkpoint.global_state.chunks_remaining.to_le_bytes());

        // Chunk section, length-prefixed.
        data.extend_from_slice(&(checkpoint.completed_chunks.len() as u64).to_le_bytes());

        for chunk in &checkpoint.completed_chunks {
            data.extend_from_slice(&chunk.chunk_id.0.to_le_bytes());
            data.extend_from_slice(&chunk.node_id.0.to_le_bytes());

            let time_f64 = chunk.final_time.to_f64().unwrap_or(0.0);
            data.extend_from_slice(&time_f64.to_le_bytes());

            // State vector, length-prefixed, values widened to f64.
            data.extend_from_slice(&(chunk.final_state.len() as u64).to_le_bytes());
            for val in chunk.final_state.iter() {
                let val_f64 = val.to_f64().unwrap_or(0.0);
                data.extend_from_slice(&val_f64.to_le_bytes());
            }
        }

        Ok(data)
    }
403
    /// Decide whether a new checkpoint should be taken now.
    ///
    /// Triggers when either `config.interval_chunks` chunks have completed
    /// since the last checkpoint, or `config.interval_duration` has elapsed
    /// since the most recent checkpoint was recorded.
    pub fn should_checkpoint(&self, chunks_since_last: usize) -> bool {
        // Chunk-count trigger: cheapest check, done first.
        if chunks_since_last >= self.config.interval_chunks {
            return true;
        }

        // Time-based trigger, using the rolling log maintained by
        // `create_checkpoint`.
        if let Ok(times) = self.checkpoint_times.lock() {
            if let Some(last_time) = times.back() {
                if last_time.elapsed() >= self.config.interval_duration {
                    return true;
                }
            } else {
                // No checkpoint has ever been taken: checkpoint as soon as
                // any work has completed, so a recovery point exists early.
                return chunks_since_last > 0;
            }
        }

        // Neither trigger fired (or the lock was poisoned).
        false
    }
425
426 pub fn get_statistics(&self) -> CheckpointStatistics {
428 let mut total_checkpoints = 0;
429 let mut total_chunks_saved = 0;
430
431 if let Ok(checkpoints) = self.checkpoints.read() {
432 for (_, job_cps) in checkpoints.iter() {
433 total_checkpoints += job_cps.len();
434 for cp in job_cps {
435 total_chunks_saved += cp.completed_chunks.len();
436 }
437 }
438 }
439
440 CheckpointStatistics {
441 total_checkpoints,
442 total_chunks_saved,
443 storage_path: self.storage_path.clone(),
444 }
445 }
446}
447
/// Aggregate counters reported by `CheckpointManager::get_statistics`.
#[derive(Debug, Clone)]
pub struct CheckpointStatistics {
    /// Checkpoints currently held in memory, summed over all jobs.
    pub total_checkpoints: usize,
    /// Chunk states held across those checkpoints.
    pub total_chunks_saved: usize,
    /// Directory used for on-disk persistence.
    pub storage_path: PathBuf,
}
458
/// Decides how the system reacts to node and chunk failures, according to
/// the configured `FaultToleranceMode`.
pub struct FaultToleranceCoordinator<F: IntegrateFloat> {
    /// Source of checkpoints for `recover_job`.
    checkpoint_manager: Arc<CheckpointManager<F>>,
    /// Recovery policy in effect.
    mode: FaultToleranceMode,
    /// Nodes currently marked as failed.
    failed_nodes: RwLock<HashSet<NodeId>>,
    /// Chunks queued for re-execution.
    pending_retry: Mutex<Vec<ChunkId>>,
    /// Observers invoked after a successful `recover_job`.
    recovery_callbacks: RwLock<Vec<Arc<dyn Fn(JobId) + Send + Sync>>>,
}
472
impl<F: IntegrateFloat> FaultToleranceCoordinator<F> {
    /// Build a coordinator with empty failure/retry state.
    pub fn new(checkpoint_manager: Arc<CheckpointManager<F>>, mode: FaultToleranceMode) -> Self {
        Self {
            mode,
            checkpoint_manager,
            failed_nodes: RwLock::new(HashSet::new()),
            pending_retry: Mutex::new(Vec::new()),
            recovery_callbacks: RwLock::new(Vec::new()),
        }
    }
484
485 pub fn handle_node_failure(
487 &self,
488 node_id: NodeId,
489 affected_chunks: Vec<ChunkId>,
490 ) -> DistributedResult<RecoveryAction> {
491 if let Ok(mut failed) = self.failed_nodes.write() {
493 failed.insert(node_id);
494 }
495
496 match self.mode {
497 FaultToleranceMode::None => {
498 Err(DistributedError::NodeFailure(
500 node_id,
501 "Node failed, no fault tolerance enabled".to_string(),
502 ))
503 }
504 FaultToleranceMode::Standard => {
505 if let Ok(mut pending) = self.pending_retry.lock() {
507 pending.extend(affected_chunks.iter().cloned());
508 }
509 Ok(RecoveryAction::RetryChunks(affected_chunks))
510 }
511 FaultToleranceMode::HighAvailability => {
512 if let Ok(mut pending) = self.pending_retry.lock() {
514 pending.extend(affected_chunks.iter().cloned());
515 }
516 Ok(RecoveryAction::FailoverAndRetry(affected_chunks))
517 }
518 FaultToleranceMode::CheckpointRecovery => {
519 Ok(RecoveryAction::RestoreFromCheckpoint)
521 }
522 }
523 }
524
525 pub fn handle_chunk_failure(
527 &self,
528 chunk_id: ChunkId,
529 node_id: NodeId,
530 error: &str,
531 can_retry: bool,
532 ) -> DistributedResult<RecoveryAction> {
533 if can_retry && self.mode != FaultToleranceMode::None {
534 if let Ok(mut pending) = self.pending_retry.lock() {
535 pending.push(chunk_id);
536 }
537 Ok(RecoveryAction::RetryChunks(vec![chunk_id]))
538 } else if self.mode == FaultToleranceMode::CheckpointRecovery {
539 Ok(RecoveryAction::RestoreFromCheckpoint)
540 } else {
541 Err(DistributedError::ChunkError(
542 chunk_id,
543 format!("Unrecoverable error on node {}: {}", node_id, error),
544 ))
545 }
546 }
547
548 pub fn get_pending_retries(&self) -> Vec<ChunkId> {
550 match self.pending_retry.lock() {
551 Ok(pending) => pending.clone(),
552 Err(_) => Vec::new(),
553 }
554 }
555
556 pub fn clear_pending_retries(&self) -> Vec<ChunkId> {
558 match self.pending_retry.lock() {
559 Ok(mut pending) => std::mem::take(&mut *pending),
560 Err(_) => Vec::new(),
561 }
562 }
563
564 pub fn is_node_failed(&self, node_id: NodeId) -> bool {
566 match self.failed_nodes.read() {
567 Ok(failed) => failed.contains(&node_id),
568 Err(_) => false,
569 }
570 }
571
572 pub fn mark_node_recovered(&self, node_id: NodeId) {
574 if let Ok(mut failed) = self.failed_nodes.write() {
575 failed.remove(&node_id);
576 }
577 }
578
579 pub fn recover_job(&self, job_id: JobId) -> DistributedResult<Checkpoint<F>> {
581 let checkpoint = self.checkpoint_manager.restore(job_id, None)?;
582
583 if let Ok(callbacks) = self.recovery_callbacks.read() {
585 for cb in callbacks.iter() {
586 cb(job_id);
587 }
588 }
589
590 Ok(checkpoint)
591 }
592
593 pub fn on_recovery<F2>(&self, callback: F2)
595 where
596 F2: Fn(JobId) + Send + Sync + 'static,
597 {
598 if let Ok(mut callbacks) = self.recovery_callbacks.write() {
599 callbacks.push(Arc::new(callback));
600 }
601 }
602
603 pub fn failed_node_count(&self) -> usize {
605 match self.failed_nodes.read() {
606 Ok(failed) => failed.len(),
607 Err(_) => 0,
608 }
609 }
610}
611
/// Action the caller should take after a failure has been handled.
#[derive(Debug, Clone)]
pub enum RecoveryAction {
    /// Re-execute the listed chunks.
    RetryChunks(Vec<ChunkId>),
    /// Re-execute the listed chunks after failing over (presumably to a
    /// different node — confirm against the scheduler that consumes this).
    FailoverAndRetry(Vec<ChunkId>),
    /// Roll the job back to a checkpoint via `recover_job`.
    RestoreFromCheckpoint,
    /// Take no action.
    None,
}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Temp directory unique per process AND per test (`tag`), so tests
    /// running in parallel threads cannot delete each other's files.
    fn temp_storage_path(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "scirs_checkpoint_test_{}_{}",
            tag,
            std::process::id()
        ))
    }

    #[test]
    fn test_checkpoint_creation() {
        let path = temp_storage_path("creation");
        let manager: CheckpointManager<f64> =
            CheckpointManager::new(path.clone(), CheckpointConfig::default())
                .expect("Failed to create manager");

        let job_id = JobId::new(1);
        let global_state = CheckpointGlobalState::default();

        let checkpoint_id = manager
            .create_checkpoint(job_id, Vec::new(), Vec::new(), global_state)
            .expect("Failed to create checkpoint");

        // Ids start at 1, so the first checkpoint id is always non-zero.
        assert!(checkpoint_id > 0);
        assert!(manager.get_latest_checkpoint(job_id).is_some());

        let _ = fs::remove_dir_all(&path);
    }

    #[test]
    fn test_checkpoint_restore() {
        let path = temp_storage_path("restore");
        // Struct-update syntax instead of mutating a default field by field.
        let config = CheckpointConfig {
            persist_to_disk: false,
            ..CheckpointConfig::default()
        };

        let manager: CheckpointManager<f64> =
            CheckpointManager::new(path.clone(), config).expect("Failed to create manager");

        let job_id = JobId::new(1);
        let global_state = CheckpointGlobalState {
            iteration: 5,
            chunks_completed: 10,
            ..Default::default()
        };

        let _ = manager.create_checkpoint(job_id, Vec::new(), Vec::new(), global_state.clone());

        // Restore must round-trip the global state and pass hash validation.
        let restored = manager.restore(job_id, None).expect("Failed to restore");
        assert_eq!(restored.global_state.iteration, 5);
        assert_eq!(restored.global_state.chunks_completed, 10);

        let _ = fs::remove_dir_all(&path);
    }

    #[test]
    fn test_fault_tolerance_coordinator() {
        let path = temp_storage_path("coordinator");
        let config = CheckpointConfig {
            persist_to_disk: false,
            ..CheckpointConfig::default()
        };

        let manager = Arc::new(
            CheckpointManager::<f64>::new(path.clone(), config).expect("Failed to create manager"),
        );

        let coordinator = FaultToleranceCoordinator::new(manager, FaultToleranceMode::Standard);

        // Standard mode should queue both affected chunks for retry.
        let action = coordinator
            .handle_node_failure(NodeId::new(1), vec![ChunkId::new(1), ChunkId::new(2)])
            .expect("Failed to handle failure");

        match action {
            RecoveryAction::RetryChunks(chunks) => {
                assert_eq!(chunks.len(), 2);
            }
            _ => panic!("Expected RetryChunks action"),
        }

        // The failed node must be tracked regardless of the returned action.
        assert!(coordinator.is_node_failed(NodeId::new(1)));

        let _ = fs::remove_dir_all(&path);
    }
}