Skip to main content

torsh_distributed/
fault_tolerance.rs

1//! Fault tolerance features for distributed training
2//!
3//! This module provides comprehensive fault tolerance capabilities including:
4//! - Checkpointing system for saving and restoring training state
5//! - Elastic training support for dynamic worker scaling
6//! - State synchronization during scaling events
7//! - Integration with error recovery mechanisms
8
9// Framework infrastructure - components designed for future use
10#![allow(dead_code)]
11use crate::error_recovery::{CircuitBreakerConfig, FailureDetector, RetryConfig};
12use crate::{TorshDistributedError, TorshResult};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::{Path, PathBuf};
16use std::sync::{Arc, Mutex, RwLock};
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
18use tokio::fs;
19use torsh_nn::Parameter;
20use torsh_tensor::Tensor;
21use tracing::{debug, info, warn};
22
/// Settings that control where and how training checkpoints are written.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Directory that checkpoint files are written into (created by `CheckpointManager::new`).
    pub checkpoint_dir: PathBuf,
    /// Intended number of training steps between checkpoints.
    /// `CheckpointManager` in this module does not read it; callers decide when to save.
    pub checkpoint_frequency: usize,
    /// How many checkpoint files saved through one manager are retained; older ones are deleted.
    pub max_checkpoints: usize,
    /// Request non-blocking saves. Not consulted by `CheckpointManager` in this module.
    pub async_save: bool,
    /// Compression level on a 0-9 scale. Not consulted by `CheckpointManager` in this module,
    /// which writes uncompressed pretty-printed JSON.
    pub compression_level: u8,
    /// Re-read and parse every checkpoint immediately after writing it.
    pub verify_after_save: bool,
}

impl Default for CheckpointConfig {
    /// Defaults: `./checkpoints`, every 1000 steps, keep 5, async saves, compression level 3,
    /// verification enabled.
    fn default() -> Self {
        let checkpoint_dir = PathBuf::from("./checkpoints");
        CheckpointConfig {
            checkpoint_dir,
            checkpoint_frequency: 1_000,
            max_checkpoints: 5,
            async_save: true,
            compression_level: 3,
            verify_after_save: true,
        }
    }
}
52
/// Settings for elastic (dynamically resized) training.
#[derive(Debug, Clone)]
pub struct ElasticConfig {
    /// Lower bound applied to every computed target worker count.
    pub min_workers: usize,
    /// Upper bound applied to every computed target worker count.
    pub max_workers: usize,
    /// Time after which an in-progress scaling phase is treated as finished.
    pub scaling_timeout: Duration,
    /// Intended polling interval for scaling checks (not read by this module).
    pub scaling_check_interval: Duration,
    /// Enables the load-based auto-scaling check while training is stable.
    pub enable_elastic_scheduling: bool,
    /// Name of the rendezvous backend used for worker coordination (not read by this module).
    pub rendezvous_backend: String,
    /// Address of the rendezvous service (not read by this module).
    pub rendezvous_endpoint: String,
}

impl Default for ElasticConfig {
    /// Defaults: 1..=64 workers, 5 minute scaling timeout, 30 s check interval, elastic
    /// scheduling on, `etcd` rendezvous at `localhost:2379`.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        let thirty_seconds = Duration::from_secs(30);
        ElasticConfig {
            min_workers: 1,
            max_workers: 64,
            scaling_timeout: five_minutes,
            scaling_check_interval: thirty_seconds,
            enable_elastic_scheduling: true,
            rendezvous_backend: String::from("etcd"),
            rendezvous_endpoint: String::from("localhost:2379"),
        }
    }
}
85
/// Training checkpoint containing all necessary state.
///
/// Serialized to JSON by `CheckpointManager::save_checkpoint`. Field order is part of the
/// serialized layout and of the `Debug` output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Step number when checkpoint was created (also encoded in the checkpoint file name)
    pub step: usize,
    /// Epoch number
    pub epoch: usize,
    /// Model parameters, keyed by parameter name.
    /// `checkpoint_utils::create_checkpoint` stores each tensor flattened; shapes are not kept.
    pub model_state: HashMap<String, Vec<f32>>,
    /// Optimizer state, keyed by name (flattened like `model_state`)
    pub optimizer_state: HashMap<String, Vec<f32>>,
    /// Learning rate scheduler state
    pub scheduler_state: HashMap<String, f32>,
    /// Random number generator states (opaque bytes; encoding is up to the producer)
    pub rng_states: HashMap<String, Vec<u8>>,
    /// Loss value at checkpoint
    pub loss: f32,
    /// Metrics at checkpoint
    pub metrics: HashMap<String, f32>,
    /// Training configuration as free-form key/value strings
    pub config: HashMap<String, String>,
    /// Timestamp when checkpoint was created; `create_checkpoint` fills it with whole seconds
    /// since the Unix epoch
    pub timestamp: u64,
    /// Version identifier (`create_checkpoint` writes "1.0.0")
    pub version: String,
    /// Distributed training metadata
    pub distributed_meta: DistributedMetadata,
}
114
/// Metadata about the distributed training state at the time a checkpoint was taken.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMetadata {
    /// World size when checkpoint was created
    pub world_size: usize,
    /// Rank of the process that created the checkpoint
    pub rank: usize,
    /// Process group information (free-form key/value pairs)
    pub process_group_info: HashMap<String, String>,
    /// Data parallel group size (`create_checkpoint` sets this to `world_size`)
    pub dp_size: usize,
    /// Tensor parallel group size
    pub tp_size: usize,
    /// Pipeline parallel group size
    pub pp_size: usize,
    /// FSDP sharding information. NOTE(review): the meaning of the `Vec<usize>` values is not
    /// established in this module (`create_checkpoint` leaves the map empty) — confirm with
    /// the FSDP producer.
    pub fsdp_sharding: HashMap<String, Vec<usize>>,
}
133
/// Events that can trigger elastic scaling.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalingEvent {
    /// Worker failure detected; the target size shrinks by the number of failed ranks
    WorkerFailure { failed_ranks: Vec<usize> },
    /// New workers joining; the target size grows by the number of new ranks
    WorkerJoin { new_ranks: Vec<usize> },
    /// Manual scaling request (issued by `ElasticTrainingManager::scale_to`)
    ManualScale { target_workers: usize },
    /// Automatic scaling based on load, with a human-readable reason
    AutoScale {
        target_workers: usize,
        reason: String,
    },
}
149
/// State of the elastic scaling state machine: `Stable` -> `Scaling` -> `Synchronizing` -> `Stable`.
#[derive(Debug, Clone)]
pub enum ScalingState {
    /// Normal training, no scaling
    Stable,
    /// Scaling in progress
    Scaling {
        /// Event that started this scaling operation
        event: ScalingEvent,
        /// Wall-clock time the operation started
        start_time: SystemTime,
        /// Target worker count, already clamped to the configured min/max
        expected_workers: usize,
    },
    /// Waiting for workers to synchronize
    Synchronizing {
        /// World size before the scaling operation is applied
        current_workers: usize,
        /// World size that will be applied when synchronization completes
        target_workers: usize,
    },
}
167
/// Checkpoint manager for saving and restoring training state.
#[derive(Debug)]
pub struct CheckpointManager {
    /// Checkpoint settings (directory, retention, verification)
    config: CheckpointConfig,
    /// Wraps checkpoint file reads/writes with retry and circuit-breaker handling
    failure_detector: FailureDetector,
    /// Last checkpoint saved through this manager; not updated when loading from disk
    latest_checkpoint: Arc<RwLock<Option<TrainingCheckpoint>>>,
    /// Paths saved through this manager, oldest first; drives retention-based deletion.
    /// Files left over from earlier processes are not tracked here.
    checkpoint_history: Arc<RwLock<Vec<PathBuf>>>,
}
176
177impl CheckpointManager {
178    /// Create a new checkpoint manager
179    pub fn new(
180        config: CheckpointConfig,
181        health_check_interval: Duration,
182        health_timeout: Duration,
183    ) -> TorshResult<Self> {
184        // Ensure checkpoint directory exists
185        std::fs::create_dir_all(&config.checkpoint_dir).map_err(|e| {
186            TorshDistributedError::backend_error(
187                "checkpoint",
188                format!("Failed to create checkpoint directory: {}", e),
189            )
190        })?;
191
192        let retry_config = RetryConfig::default();
193        let circuit_breaker_config = CircuitBreakerConfig::default();
194
195        let failure_detector = FailureDetector::new(
196            health_check_interval,
197            health_timeout,
198            retry_config,
199            Some(circuit_breaker_config),
200        );
201
202        Ok(Self {
203            config,
204            failure_detector,
205            latest_checkpoint: Arc::new(RwLock::new(None)),
206            checkpoint_history: Arc::new(RwLock::new(Vec::new())),
207        })
208    }
209
210    /// Save a training checkpoint
211    pub async fn save_checkpoint(
212        &self,
213        checkpoint: TrainingCheckpoint,
214        rank: usize,
215    ) -> TorshResult<PathBuf> {
216        let checkpoint_path = self.config.checkpoint_dir.join(format!(
217            "checkpoint_step_{}_rank_{}.json",
218            checkpoint.step, rank
219        ));
220
221        info!(
222            "Saving checkpoint at step {} to {:?}",
223            checkpoint.step, checkpoint_path
224        );
225
226        let checkpoint_data = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
227            TorshDistributedError::backend_error(
228                "checkpoint",
229                format!("Failed to serialize checkpoint: {}", e),
230            )
231        })?;
232
233        // Save with error recovery
234        self.failure_detector
235            .execute_with_recovery(
236                || async {
237                    fs::write(&checkpoint_path, &checkpoint_data)
238                        .await
239                        .map_err(|e| {
240                            TorshDistributedError::backend_error(
241                                "checkpoint",
242                                format!("Failed to write checkpoint: {}", e),
243                            )
244                        })
245                },
246                None,
247            )
248            .await?;
249
250        // Verify checkpoint if configured
251        if self.config.verify_after_save {
252            self.verify_checkpoint(&checkpoint_path).await?;
253        }
254
255        // Store checkpoint step before moving
256        let checkpoint_step = checkpoint.step;
257
258        // Update internal state
259        {
260            let mut latest = self
261                .latest_checkpoint
262                .write()
263                .expect("lock should not be poisoned");
264            *latest = Some(checkpoint);
265        }
266
267        {
268            let mut history = self
269                .checkpoint_history
270                .write()
271                .expect("lock should not be poisoned");
272            history.push(checkpoint_path.clone());
273
274            // Clean up old checkpoints
275            if history.len() > self.config.max_checkpoints {
276                let old_checkpoint = history.remove(0);
277                if let Err(e) = std::fs::remove_file(&old_checkpoint) {
278                    warn!(
279                        "Failed to remove old checkpoint {:?}: {}",
280                        old_checkpoint, e
281                    );
282                }
283            }
284        }
285
286        info!("Successfully saved checkpoint at step {}", checkpoint_step);
287        Ok(checkpoint_path)
288    }
289
290    /// Load the latest checkpoint
291    pub async fn load_latest_checkpoint(&self) -> TorshResult<Option<TrainingCheckpoint>> {
292        let checkpoint_files = self.find_checkpoint_files().await?;
293
294        if checkpoint_files.is_empty() {
295            info!("No checkpoints found");
296            return Ok(None);
297        }
298
299        // Find the latest checkpoint by step number
300        let latest_file = checkpoint_files
301            .iter()
302            .max_by_key(|path| self.extract_step_from_filename(path))
303            .expect("checkpoint_files should not be empty");
304
305        info!("Loading latest checkpoint from {:?}", latest_file);
306        self.load_checkpoint(latest_file).await
307    }
308
309    /// Load a specific checkpoint
310    pub async fn load_checkpoint(
311        &self,
312        checkpoint_path: &PathBuf,
313    ) -> TorshResult<Option<TrainingCheckpoint>> {
314        self.failure_detector
315            .execute_with_recovery(
316                || async {
317                    let checkpoint_data =
318                        fs::read_to_string(checkpoint_path).await.map_err(|e| {
319                            TorshDistributedError::backend_error(
320                                "checkpoint",
321                                format!("Failed to read checkpoint: {}", e),
322                            )
323                        })?;
324
325                    let checkpoint: TrainingCheckpoint = serde_json::from_str(&checkpoint_data)
326                        .map_err(|e| {
327                            TorshDistributedError::backend_error(
328                                "checkpoint",
329                                format!("Failed to deserialize checkpoint: {}", e),
330                            )
331                        })?;
332
333                    info!(
334                        "Successfully loaded checkpoint from step {}",
335                        checkpoint.step
336                    );
337                    Ok(Some(checkpoint))
338                },
339                None,
340            )
341            .await
342    }
343
344    /// Verify checkpoint integrity
345    async fn verify_checkpoint(&self, checkpoint_path: &PathBuf) -> TorshResult<()> {
346        debug!("Verifying checkpoint {:?}", checkpoint_path);
347
348        // Load and parse to ensure it's valid
349        let checkpoint = self.load_checkpoint(checkpoint_path).await?;
350        if checkpoint.is_none() {
351            return Err(TorshDistributedError::backend_error(
352                "checkpoint",
353                "Checkpoint verification failed: could not load",
354            ));
355        }
356
357        debug!("Checkpoint verification successful");
358        Ok(())
359    }
360
361    /// Find all checkpoint files in the directory
362    async fn find_checkpoint_files(&self) -> TorshResult<Vec<PathBuf>> {
363        let mut checkpoint_files = Vec::new();
364
365        let mut dir_entries = fs::read_dir(&self.config.checkpoint_dir)
366            .await
367            .map_err(|e| {
368                TorshDistributedError::backend_error(
369                    "checkpoint",
370                    format!("Failed to read checkpoint directory: {}", e),
371                )
372            })?;
373
374        while let Some(entry) = dir_entries.next_entry().await.map_err(|e| {
375            TorshDistributedError::backend_error(
376                "checkpoint",
377                format!("Failed to read directory entry: {}", e),
378            )
379        })? {
380            let path = entry.path();
381            if let Some(filename) = path.file_name() {
382                if filename.to_string_lossy().starts_with("checkpoint_")
383                    && filename.to_string_lossy().ends_with(".json")
384                {
385                    checkpoint_files.push(path);
386                }
387            }
388        }
389
390        Ok(checkpoint_files)
391    }
392
393    /// Extract step number from checkpoint filename
394    fn extract_step_from_filename(&self, path: &Path) -> usize {
395        if let Some(filename) = path.file_stem() {
396            let filename_str = filename.to_string_lossy();
397            if let Some(step_start) = filename_str.find("step_") {
398                let step_part = &filename_str[step_start + 5..];
399                if let Some(rank_pos) = step_part.find("_rank_") {
400                    let step_str = &step_part[..rank_pos];
401                    return step_str.parse().unwrap_or(0);
402                }
403            }
404        }
405        0
406    }
407
408    /// Get the latest checkpoint metadata without loading the full checkpoint
409    pub fn get_latest_checkpoint_info(&self) -> Option<TrainingCheckpoint> {
410        self.latest_checkpoint
411            .read()
412            .expect("lock should not be poisoned")
413            .clone()
414    }
415
416    /// Clean up all checkpoints (useful for cleanup)
417    pub async fn cleanup_all_checkpoints(&self) -> TorshResult<()> {
418        let checkpoint_files = self.find_checkpoint_files().await?;
419
420        for file in checkpoint_files {
421            if let Err(e) = fs::remove_file(&file).await {
422                warn!("Failed to remove checkpoint file {:?}: {}", file, e);
423            }
424        }
425
426        {
427            let mut history = self
428                .checkpoint_history
429                .write()
430                .expect("lock should not be poisoned");
431            history.clear();
432        }
433        {
434            let mut latest = self
435                .latest_checkpoint
436                .write()
437                .expect("lock should not be poisoned");
438            *latest = None;
439        }
440
441        info!("Cleaned up all checkpoints");
442        Ok(())
443    }
444}
445
/// Elastic training manager for handling dynamic scaling.
#[derive(Debug)]
pub struct ElasticTrainingManager {
    /// Elastic scaling settings (bounds, timeouts, auto-scaling switch)
    config: ElasticConfig,
    /// Current position in the Stable -> Scaling -> Synchronizing -> Stable state machine
    scaling_state: Arc<RwLock<ScalingState>>,
    /// Checkpoint manager used to persist state before scaling
    checkpoint_manager: CheckpointManager,
    /// World size currently in effect; updated when synchronization completes
    current_world_size: Arc<RwLock<usize>>,
    /// NOTE(review): initialized empty and never read or written elsewhere in this module —
    /// presumably intended to map rank -> last-seen time; confirm before relying on it.
    worker_registry: Arc<RwLock<HashMap<usize, SystemTime>>>,
    /// Recent scaling events, oldest first; trimmed to the newest 50 once it exceeds 100
    scaling_events: Arc<Mutex<Vec<ScalingEvent>>>,
}
456
457impl ElasticTrainingManager {
458    /// Create a new elastic training manager
459    pub fn new(
460        config: ElasticConfig,
461        checkpoint_config: CheckpointConfig,
462        initial_world_size: usize,
463    ) -> TorshResult<Self> {
464        let checkpoint_manager = CheckpointManager::new(
465            checkpoint_config,
466            Duration::from_secs(30),  // health check interval
467            Duration::from_secs(120), // health timeout
468        )?;
469
470        Ok(Self {
471            config,
472            scaling_state: Arc::new(RwLock::new(ScalingState::Stable)),
473            checkpoint_manager,
474            current_world_size: Arc::new(RwLock::new(initial_world_size)),
475            worker_registry: Arc::new(RwLock::new(HashMap::new())),
476            scaling_events: Arc::new(Mutex::new(Vec::new())),
477        })
478    }
479
480    /// Check if scaling is needed and initiate if necessary
481    pub async fn check_scaling_needs(&self) -> TorshResult<Option<ScalingEvent>> {
482        let current_state = self
483            .scaling_state
484            .read()
485            .expect("lock should not be poisoned")
486            .clone();
487
488        match current_state {
489            ScalingState::Stable => {
490                // Check for worker failures or new joins
491                let current_workers = *self
492                    .current_world_size
493                    .read()
494                    .expect("lock should not be poisoned");
495
496                // Simulate failure detection (in real implementation, this would check actual worker health)
497                let failed_workers = self.detect_failed_workers().await?;
498                if !failed_workers.is_empty() {
499                    let event = ScalingEvent::WorkerFailure {
500                        failed_ranks: failed_workers,
501                    };
502                    info!("Detected worker failures, initiating scaling: {:?}", event);
503                    self.initiate_scaling(event.clone()).await?;
504                    return Ok(Some(event));
505                }
506
507                // Check for new workers
508                let new_workers = self.detect_new_workers().await?;
509                if !new_workers.is_empty() {
510                    let event = ScalingEvent::WorkerJoin {
511                        new_ranks: new_workers,
512                    };
513                    info!("Detected new workers, initiating scaling: {:?}", event);
514                    self.initiate_scaling(event.clone()).await?;
515                    return Ok(Some(event));
516                }
517
518                // Check for automatic scaling based on load (simplified)
519                if self.config.enable_elastic_scheduling {
520                    if let Some(target) = self.calculate_optimal_workers(current_workers).await? {
521                        if target != current_workers {
522                            let event = ScalingEvent::AutoScale {
523                                target_workers: target,
524                                reason: "Load-based scaling".to_string(),
525                            };
526                            info!("Initiating auto-scaling: {:?}", event);
527                            self.initiate_scaling(event.clone()).await?;
528                            return Ok(Some(event));
529                        }
530                    }
531                }
532            }
533            ScalingState::Scaling { .. } => {
534                // Check if scaling has completed
535                if self.is_scaling_complete().await? {
536                    self.complete_scaling().await?;
537                }
538            }
539            ScalingState::Synchronizing { .. } => {
540                // Check if synchronization has completed
541                if self.is_synchronization_complete().await? {
542                    self.complete_synchronization().await?;
543                }
544            }
545        }
546
547        Ok(None)
548    }
549
550    /// Initiate scaling process
551    async fn initiate_scaling(&self, event: ScalingEvent) -> TorshResult<()> {
552        info!("Initiating scaling for event: {:?}", event);
553
554        // Save checkpoint before scaling
555        if let Some(checkpoint) = self.checkpoint_manager.get_latest_checkpoint_info() {
556            self.checkpoint_manager
557                .save_checkpoint(checkpoint, 0)
558                .await?;
559        }
560
561        let expected_workers = match &event {
562            ScalingEvent::WorkerFailure { failed_ranks } => {
563                *self
564                    .current_world_size
565                    .read()
566                    .expect("lock should not be poisoned")
567                    - failed_ranks.len()
568            }
569            ScalingEvent::WorkerJoin { new_ranks } => {
570                *self
571                    .current_world_size
572                    .read()
573                    .expect("lock should not be poisoned")
574                    + new_ranks.len()
575            }
576            ScalingEvent::ManualScale { target_workers }
577            | ScalingEvent::AutoScale { target_workers, .. } => *target_workers,
578        };
579
580        // Ensure we stay within bounds
581        let expected_workers = expected_workers
582            .max(self.config.min_workers)
583            .min(self.config.max_workers);
584
585        {
586            let mut state = self
587                .scaling_state
588                .write()
589                .expect("lock should not be poisoned");
590            *state = ScalingState::Scaling {
591                event: event.clone(),
592                start_time: SystemTime::now(),
593                expected_workers,
594            };
595        }
596
597        // Add to event history
598        {
599            let mut events = self
600                .scaling_events
601                .lock()
602                .expect("lock should not be poisoned");
603            events.push(event);
604            // Keep only recent events
605            if events.len() > 100 {
606                events.drain(0..50);
607            }
608        }
609
610        Ok(())
611    }
612
613    /// Check if scaling is complete
614    async fn is_scaling_complete(&self) -> TorshResult<bool> {
615        // Simplified: check if enough time has passed
616        if let ScalingState::Scaling { start_time, .. } = *self
617            .scaling_state
618            .read()
619            .expect("lock should not be poisoned")
620        {
621            Ok(start_time.elapsed().unwrap_or(Duration::ZERO) >= self.config.scaling_timeout)
622        } else {
623            Ok(false)
624        }
625    }
626
627    /// Complete the scaling process
628    async fn complete_scaling(&self) -> TorshResult<()> {
629        info!("Completing scaling process");
630
631        let expected_workers = if let ScalingState::Scaling {
632            expected_workers, ..
633        } = *self
634            .scaling_state
635            .read()
636            .expect("lock should not be poisoned")
637        {
638            expected_workers
639        } else {
640            return Ok(());
641        };
642
643        // Transition to synchronization state
644        {
645            let mut state = self
646                .scaling_state
647                .write()
648                .expect("lock should not be poisoned");
649            *state = ScalingState::Synchronizing {
650                current_workers: *self
651                    .current_world_size
652                    .read()
653                    .expect("lock should not be poisoned"),
654                target_workers: expected_workers,
655            };
656        }
657
658        info!("Transitioning to synchronization phase");
659        Ok(())
660    }
661
662    /// Check if synchronization is complete
663    async fn is_synchronization_complete(&self) -> TorshResult<bool> {
664        // Simplified: assume synchronization completes quickly
665        Ok(true)
666    }
667
668    /// Complete the synchronization process
669    async fn complete_synchronization(&self) -> TorshResult<()> {
670        info!("Completing synchronization process");
671
672        let target_workers = if let ScalingState::Synchronizing { target_workers, .. } = *self
673            .scaling_state
674            .read()
675            .expect("lock should not be poisoned")
676        {
677            target_workers
678        } else {
679            return Ok(());
680        };
681
682        // Update world size
683        {
684            let mut world_size = self
685                .current_world_size
686                .write()
687                .expect("lock should not be poisoned");
688            *world_size = target_workers;
689        }
690
691        // Return to stable state
692        {
693            let mut state = self
694                .scaling_state
695                .write()
696                .expect("lock should not be poisoned");
697            *state = ScalingState::Stable;
698        }
699
700        info!(
701            "Elastic scaling completed, new world size: {}",
702            target_workers
703        );
704        Ok(())
705    }
706
707    /// Detect failed workers (simplified implementation)
708    async fn detect_failed_workers(&self) -> TorshResult<Vec<usize>> {
709        // In a real implementation, this would check actual worker health
710        // For now, return empty to simulate no failures
711        Ok(Vec::new())
712    }
713
714    /// Detect new workers (simplified implementation)
715    async fn detect_new_workers(&self) -> TorshResult<Vec<usize>> {
716        // In a real implementation, this would check for new worker registrations
717        // For now, return empty to simulate no new workers
718        Ok(Vec::new())
719    }
720
721    /// Calculate optimal number of workers based on current load
722    async fn calculate_optimal_workers(
723        &self,
724        _current_workers: usize,
725    ) -> TorshResult<Option<usize>> {
726        // Simplified auto-scaling logic
727        // In a real implementation, this would consider:
728        // - Current throughput
729        // - Resource utilization
730        // - Queue lengths
731        // - Performance metrics
732
733        // For now, maintain current workers
734        Ok(None)
735    }
736
737    /// Get current scaling state
738    pub fn get_scaling_state(&self) -> ScalingState {
739        self.scaling_state
740            .read()
741            .expect("lock should not be poisoned")
742            .clone()
743    }
744
745    /// Get current world size
746    pub fn get_world_size(&self) -> usize {
747        *self
748            .current_world_size
749            .read()
750            .expect("lock should not be poisoned")
751    }
752
753    /// Force manual scaling
754    pub async fn scale_to(&self, target_workers: usize) -> TorshResult<()> {
755        let event = ScalingEvent::ManualScale { target_workers };
756        self.initiate_scaling(event).await
757    }
758
759    /// Get scaling event history
760    pub fn get_scaling_history(&self) -> Vec<ScalingEvent> {
761        self.scaling_events
762            .lock()
763            .expect("lock should not be poisoned")
764            .clone()
765    }
766
767    /// Check if training can proceed (not currently scaling)
768    pub fn can_proceed_training(&self) -> bool {
769        matches!(
770            *self
771                .scaling_state
772                .read()
773                .expect("lock should not be poisoned"),
774            ScalingState::Stable
775        )
776    }
777
778    /// Get checkpoint manager reference
779    pub fn checkpoint_manager(&self) -> &CheckpointManager {
780        &self.checkpoint_manager
781    }
782}
783
784/// Utility functions for creating training checkpoints
785pub mod checkpoint_utils {
786    use super::*;
787
788    /// Parameters for creating a checkpoint
789    #[allow(dead_code)]
790    pub struct CheckpointParams {
791        pub step: usize,
792        pub epoch: usize,
793        pub model_params: HashMap<String, Parameter>,
794        pub optimizer_state: HashMap<String, Tensor>,
795        pub loss: f32,
796        pub metrics: HashMap<String, f32>,
797        pub world_size: usize,
798        pub rank: usize,
799    }
800
801    /// Create a checkpoint from model parameters and training state
802    pub fn create_checkpoint(params: CheckpointParams) -> TorshResult<TrainingCheckpoint> {
803        let CheckpointParams {
804            step,
805            epoch,
806            model_params,
807            optimizer_state,
808            loss,
809            metrics,
810            world_size,
811            rank,
812        } = params;
813        // Convert model parameters to serializable format
814        let mut model_state = HashMap::new();
815        for (name, param) in model_params {
816            let tensor = param.tensor();
817            let tensor_guard = tensor.read();
818            let data = tensor_guard.flatten()?.to_vec()?;
819            model_state.insert(name, data);
820        }
821
822        // Convert optimizer state to serializable format
823        let mut opt_state = HashMap::new();
824        for (name, tensor) in optimizer_state {
825            let data = tensor.flatten()?.to_vec()?;
826            opt_state.insert(name, data);
827        }
828
829        let distributed_meta = DistributedMetadata {
830            world_size,
831            rank,
832            process_group_info: HashMap::new(),
833            dp_size: world_size, // Simplified
834            tp_size: 1,
835            pp_size: 1,
836            fsdp_sharding: HashMap::new(),
837        };
838
839        Ok(TrainingCheckpoint {
840            step,
841            epoch,
842            model_state,
843            optimizer_state: opt_state,
844            scheduler_state: HashMap::new(),
845            rng_states: HashMap::new(),
846            loss,
847            metrics,
848            config: HashMap::new(),
849            timestamp: SystemTime::now()
850                .duration_since(UNIX_EPOCH)
851                .expect("time should be after UNIX_EPOCH")
852                .as_secs(),
853            version: "1.0.0".to_string(),
854            distributed_meta,
855        })
856    }
857
858    /// Restore model parameters from checkpoint
859    pub fn restore_model_from_checkpoint(
860        checkpoint: &TrainingCheckpoint,
861    ) -> TorshResult<HashMap<String, Tensor>> {
862        let mut model_params = HashMap::new();
863
864        for (name, data) in &checkpoint.model_state {
865            let shape = vec![data.len()]; // Simplified shape reconstruction
866            let tensor = Tensor::from_vec(data.clone(), &shape)?;
867            model_params.insert(name.clone(), tensor);
868        }
869
870        Ok(model_params)
871    }
872
873    /// Restore optimizer state from checkpoint
874    pub fn restore_optimizer_from_checkpoint(
875        checkpoint: &TrainingCheckpoint,
876    ) -> TorshResult<HashMap<String, Tensor>> {
877        let mut optimizer_state = HashMap::new();
878
879        for (name, data) in &checkpoint.optimizer_state {
880            let shape = vec![data.len()]; // Simplified shape reconstruction
881            let tensor = Tensor::from_vec(data.clone(), &shape)?;
882            optimizer_state.insert(name.clone(), tensor);
883        }
884
885        Ok(optimizer_state)
886    }
887}
888
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Saving a checkpoint and then loading the latest one returns the same contents.
    #[tokio::test]
    async fn test_checkpoint_manager() -> TorshResult<()> {
        let temp_dir = TempDir::new().expect("failed to create temporary checkpoint dir");
        let config = CheckpointConfig {
            checkpoint_dir: temp_dir.path().to_path_buf(),
            checkpoint_frequency: 100,
            max_checkpoints: 3,
            ..Default::default()
        };

        let manager = CheckpointManager::new(
            config,
            Duration::from_millis(100),
            Duration::from_millis(200),
        )?;

        // Create a test checkpoint
        let checkpoint = TrainingCheckpoint {
            step: 1000,
            epoch: 10,
            model_state: {
                let mut state = HashMap::new();
                state.insert("weight".to_string(), vec![1.0, 2.0, 3.0]);
                state
            },
            optimizer_state: HashMap::new(),
            scheduler_state: HashMap::new(),
            rng_states: HashMap::new(),
            loss: 0.5,
            metrics: HashMap::new(),
            config: HashMap::new(),
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
            version: "1.0.0".to_string(),
            distributed_meta: DistributedMetadata {
                world_size: 4,
                rank: 0,
                process_group_info: HashMap::new(),
                dp_size: 4,
                tp_size: 1,
                pp_size: 1,
                fsdp_sharding: HashMap::new(),
            },
        };

        // Save checkpoint
        let checkpoint_path = manager.save_checkpoint(checkpoint.clone(), 0).await?;
        assert!(checkpoint_path.exists());

        // Load checkpoint; one was just saved, so `None` here is a test failure.
        let loaded_checkpoint = manager
            .load_latest_checkpoint()
            .await?
            .expect("a checkpoint was just saved, so the latest one should load");
        assert_eq!(loaded_checkpoint.step, checkpoint.step);
        assert_eq!(loaded_checkpoint.epoch, checkpoint.epoch);
        assert_eq!(loaded_checkpoint.loss, checkpoint.loss);
        // The stored parameter values must survive the save/load round-trip.
        assert_eq!(loaded_checkpoint.model_state, checkpoint.model_state);

        Ok(())
    }

    /// A manual scale request moves the manager into the `Scaling` state.
    #[tokio::test]
    async fn test_elastic_training_manager() -> TorshResult<()> {
        let temp_dir = TempDir::new().expect("failed to create temporary checkpoint dir");
        let elastic_config = ElasticConfig {
            min_workers: 2,
            max_workers: 8,
            scaling_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let checkpoint_config = CheckpointConfig {
            checkpoint_dir: temp_dir.path().to_path_buf(),
            ..Default::default()
        };

        let manager = ElasticTrainingManager::new(
            elastic_config,
            checkpoint_config,
            4, // initial world size
        )?;

        assert_eq!(manager.get_world_size(), 4);
        assert!(manager.can_proceed_training());

        // Test manual scaling
        manager.scale_to(6).await?;

        // Right after the request the manager should report the scaling target.
        match manager.get_scaling_state() {
            ScalingState::Scaling {
                expected_workers, ..
            } => {
                assert_eq!(expected_workers, 6);
            }
            _ => panic!("Expected scaling state"),
        }

        Ok(())
    }

    /// Every field of `CheckpointConfig::default()` has the documented value.
    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig::default();
        assert_eq!(config.checkpoint_dir, PathBuf::from("./checkpoints"));
        assert_eq!(config.checkpoint_frequency, 1000);
        assert_eq!(config.max_checkpoints, 5);
        assert!(config.async_save);
        assert_eq!(config.compression_level, 3);
        assert!(config.verify_after_save);
    }

    /// Spot-check the `ElasticConfig` defaults.
    #[test]
    fn test_elastic_config() {
        let config = ElasticConfig::default();
        assert_eq!(config.min_workers, 1);
        assert_eq!(config.max_workers, 64);
        assert!(config.enable_elastic_scheduling);
    }

    /// Distinct scaling events compare unequal, and identical events compare equal.
    #[test]
    fn test_scaling_events() {
        let event1 = ScalingEvent::WorkerFailure {
            failed_ranks: vec![1, 2],
        };
        let event2 = ScalingEvent::WorkerJoin {
            new_ranks: vec![5, 6],
        };
        let event3 = ScalingEvent::ManualScale { target_workers: 8 };

        // Every pair of different variants must be distinguishable.
        assert_ne!(event1, event2);
        assert_ne!(event2, event3);
        assert_ne!(event1, event3);

        // Independently constructed but identical events must be equal.
        assert_eq!(
            event1,
            ScalingEvent::WorkerFailure {
                failed_ranks: vec![1, 2],
            }
        );
        assert_eq!(event3, ScalingEvent::ManualScale { target_workers: 8 });
    }
}