1#![allow(dead_code)]
11use crate::error_recovery::{CircuitBreakerConfig, FailureDetector, RetryConfig};
12use crate::{TorshDistributedError, TorshResult};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::{Path, PathBuf};
16use std::sync::{Arc, Mutex, RwLock};
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
18use tokio::fs;
19use torsh_nn::Parameter;
20use torsh_tensor::Tensor;
21use tracing::{debug, info, warn};
22
/// Configuration for writing and retaining training checkpoints.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Directory where checkpoint JSON files are written.
    pub checkpoint_dir: PathBuf,
    /// Checkpoint cadence in training steps (not enforced within this file —
    /// callers are expected to consult it; confirm against call sites).
    pub checkpoint_frequency: usize,
    /// Maximum number of checkpoint files kept; older ones are pruned on save.
    pub max_checkpoints: usize,
    /// Whether saves should be asynchronous (not read by the save path in
    /// this file — TODO confirm where it is honored).
    pub async_save: bool,
    /// Compression level for checkpoint payloads (not applied by the JSON
    /// save path in this file — TODO confirm).
    pub compression_level: u8,
    /// When true, every saved checkpoint is re-loaded to verify integrity.
    pub verify_after_save: bool,
}
39
impl Default for CheckpointConfig {
    /// Defaults: `./checkpoints` dir, every 1000 steps, keep 5 files,
    /// async saves, compression level 3, post-save verification on.
    fn default() -> Self {
        Self {
            checkpoint_dir: PathBuf::from("./checkpoints"),
            checkpoint_frequency: 1000,
            max_checkpoints: 5,
            async_save: true,
            compression_level: 3,
            verify_after_save: true,
        }
    }
}
52
/// Configuration for elastic (dynamically resizable) training.
#[derive(Debug, Clone)]
pub struct ElasticConfig {
    /// Lower bound on worker count after any scaling operation.
    pub min_workers: usize,
    /// Upper bound on worker count after any scaling operation.
    pub max_workers: usize,
    /// How long a scaling operation runs before it is considered complete
    /// (see `is_scaling_complete`, which compares elapsed time to this).
    pub scaling_timeout: Duration,
    /// Interval between scaling checks (not read within this file —
    /// presumably used by a polling driver; confirm against callers).
    pub scaling_check_interval: Duration,
    /// Enables load-based auto-scaling in `check_scaling_needs`.
    pub enable_elastic_scheduling: bool,
    /// Rendezvous backend name (e.g. "etcd"); not read within this file.
    pub rendezvous_backend: String,
    /// Rendezvous endpoint address; not read within this file.
    pub rendezvous_endpoint: String,
}
71
72impl Default for ElasticConfig {
73 fn default() -> Self {
74 Self {
75 min_workers: 1,
76 max_workers: 64,
77 scaling_timeout: Duration::from_secs(300), scaling_check_interval: Duration::from_secs(30),
79 enable_elastic_scheduling: true,
80 rendezvous_backend: "etcd".to_string(),
81 rendezvous_endpoint: "localhost:2379".to_string(),
82 }
83 }
84}
85
/// A complete, serializable snapshot of training state at one step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Global training step at which this checkpoint was taken.
    pub step: usize,
    /// Epoch at which this checkpoint was taken.
    pub epoch: usize,
    /// Flattened model parameters, keyed by parameter name.
    pub model_state: HashMap<String, Vec<f32>>,
    /// Flattened optimizer tensors, keyed by state name.
    pub optimizer_state: HashMap<String, Vec<f32>>,
    /// Scalar scheduler state values, keyed by name.
    pub scheduler_state: HashMap<String, f32>,
    /// Serialized RNG states per generator name, for reproducible resume.
    pub rng_states: HashMap<String, Vec<u8>>,
    /// Loss value at checkpoint time.
    pub loss: f32,
    /// Arbitrary scalar metrics recorded at checkpoint time.
    pub metrics: HashMap<String, f32>,
    /// Free-form configuration key/value pairs.
    pub config: HashMap<String, String>,
    /// Seconds since UNIX epoch when the checkpoint was created.
    pub timestamp: u64,
    /// Checkpoint format version string (e.g. "1.0.0").
    pub version: String,
    /// Distributed-topology metadata captured with the checkpoint.
    pub distributed_meta: DistributedMetadata,
}
114
/// Distributed-training topology recorded alongside a checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMetadata {
    /// Total number of processes in the job.
    pub world_size: usize,
    /// Rank of the process that wrote the checkpoint.
    pub rank: usize,
    /// Free-form process-group descriptors, keyed by group name.
    pub process_group_info: HashMap<String, String>,
    /// Data-parallel group size.
    pub dp_size: usize,
    /// Tensor-parallel group size.
    pub tp_size: usize,
    /// Pipeline-parallel group size.
    pub pp_size: usize,
    /// FSDP shard layout per parameter name (shard sizes/offsets —
    /// exact semantics not established in this file; confirm at producer).
    pub fsdp_sharding: HashMap<String, Vec<usize>>,
}
133
/// Events that trigger a change in the number of training workers.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalingEvent {
    /// One or more workers were detected as failed and must be removed.
    WorkerFailure { failed_ranks: Vec<usize> },
    /// New workers joined and should be incorporated.
    WorkerJoin { new_ranks: Vec<usize> },
    /// An operator explicitly requested this worker count.
    ManualScale { target_workers: usize },
    /// The scheduler decided to resize automatically.
    AutoScale {
        /// Desired worker count.
        target_workers: usize,
        /// Human-readable justification (e.g. "Load-based scaling").
        reason: String,
    },
}
149
/// State machine for elastic scaling: Stable -> Scaling -> Synchronizing -> Stable.
#[derive(Debug, Clone)]
pub enum ScalingState {
    /// No scaling in progress; training may proceed.
    Stable,
    /// A scaling operation is underway.
    Scaling {
        /// The event that started this scaling operation.
        event: ScalingEvent,
        /// When the operation began (used against `scaling_timeout`).
        start_time: SystemTime,
        /// Worker count expected once scaling finishes.
        expected_workers: usize,
    },
    /// Workers are re-synchronizing after a scaling operation.
    Synchronizing {
        /// Worker count before the resize.
        current_workers: usize,
        /// Worker count being synchronized to.
        target_workers: usize,
    },
}
167
/// Saves, loads, verifies, and prunes training checkpoints on disk,
/// routing all file I/O through a failure detector for retry/recovery.
#[derive(Debug)]
pub struct CheckpointManager {
    /// Directory, retention, and verification settings.
    config: CheckpointConfig,
    /// Wraps I/O with retry and circuit-breaker recovery.
    failure_detector: FailureDetector,
    /// In-memory copy of the most recently saved checkpoint.
    latest_checkpoint: Arc<RwLock<Option<TrainingCheckpoint>>>,
    /// Paths of checkpoints saved by this manager, oldest first.
    checkpoint_history: Arc<RwLock<Vec<PathBuf>>>,
}
176
177impl CheckpointManager {
178 pub fn new(
180 config: CheckpointConfig,
181 health_check_interval: Duration,
182 health_timeout: Duration,
183 ) -> TorshResult<Self> {
184 std::fs::create_dir_all(&config.checkpoint_dir).map_err(|e| {
186 TorshDistributedError::backend_error(
187 "checkpoint",
188 format!("Failed to create checkpoint directory: {}", e),
189 )
190 })?;
191
192 let retry_config = RetryConfig::default();
193 let circuit_breaker_config = CircuitBreakerConfig::default();
194
195 let failure_detector = FailureDetector::new(
196 health_check_interval,
197 health_timeout,
198 retry_config,
199 Some(circuit_breaker_config),
200 );
201
202 Ok(Self {
203 config,
204 failure_detector,
205 latest_checkpoint: Arc::new(RwLock::new(None)),
206 checkpoint_history: Arc::new(RwLock::new(Vec::new())),
207 })
208 }
209
210 pub async fn save_checkpoint(
212 &self,
213 checkpoint: TrainingCheckpoint,
214 rank: usize,
215 ) -> TorshResult<PathBuf> {
216 let checkpoint_path = self.config.checkpoint_dir.join(format!(
217 "checkpoint_step_{}_rank_{}.json",
218 checkpoint.step, rank
219 ));
220
221 info!(
222 "Saving checkpoint at step {} to {:?}",
223 checkpoint.step, checkpoint_path
224 );
225
226 let checkpoint_data = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
227 TorshDistributedError::backend_error(
228 "checkpoint",
229 format!("Failed to serialize checkpoint: {}", e),
230 )
231 })?;
232
233 self.failure_detector
235 .execute_with_recovery(
236 || async {
237 fs::write(&checkpoint_path, &checkpoint_data)
238 .await
239 .map_err(|e| {
240 TorshDistributedError::backend_error(
241 "checkpoint",
242 format!("Failed to write checkpoint: {}", e),
243 )
244 })
245 },
246 None,
247 )
248 .await?;
249
250 if self.config.verify_after_save {
252 self.verify_checkpoint(&checkpoint_path).await?;
253 }
254
255 let checkpoint_step = checkpoint.step;
257
258 {
260 let mut latest = self
261 .latest_checkpoint
262 .write()
263 .expect("lock should not be poisoned");
264 *latest = Some(checkpoint);
265 }
266
267 {
268 let mut history = self
269 .checkpoint_history
270 .write()
271 .expect("lock should not be poisoned");
272 history.push(checkpoint_path.clone());
273
274 if history.len() > self.config.max_checkpoints {
276 let old_checkpoint = history.remove(0);
277 if let Err(e) = std::fs::remove_file(&old_checkpoint) {
278 warn!(
279 "Failed to remove old checkpoint {:?}: {}",
280 old_checkpoint, e
281 );
282 }
283 }
284 }
285
286 info!("Successfully saved checkpoint at step {}", checkpoint_step);
287 Ok(checkpoint_path)
288 }
289
290 pub async fn load_latest_checkpoint(&self) -> TorshResult<Option<TrainingCheckpoint>> {
292 let checkpoint_files = self.find_checkpoint_files().await?;
293
294 if checkpoint_files.is_empty() {
295 info!("No checkpoints found");
296 return Ok(None);
297 }
298
299 let latest_file = checkpoint_files
301 .iter()
302 .max_by_key(|path| self.extract_step_from_filename(path))
303 .expect("checkpoint_files should not be empty");
304
305 info!("Loading latest checkpoint from {:?}", latest_file);
306 self.load_checkpoint(latest_file).await
307 }
308
309 pub async fn load_checkpoint(
311 &self,
312 checkpoint_path: &PathBuf,
313 ) -> TorshResult<Option<TrainingCheckpoint>> {
314 self.failure_detector
315 .execute_with_recovery(
316 || async {
317 let checkpoint_data =
318 fs::read_to_string(checkpoint_path).await.map_err(|e| {
319 TorshDistributedError::backend_error(
320 "checkpoint",
321 format!("Failed to read checkpoint: {}", e),
322 )
323 })?;
324
325 let checkpoint: TrainingCheckpoint = serde_json::from_str(&checkpoint_data)
326 .map_err(|e| {
327 TorshDistributedError::backend_error(
328 "checkpoint",
329 format!("Failed to deserialize checkpoint: {}", e),
330 )
331 })?;
332
333 info!(
334 "Successfully loaded checkpoint from step {}",
335 checkpoint.step
336 );
337 Ok(Some(checkpoint))
338 },
339 None,
340 )
341 .await
342 }
343
344 async fn verify_checkpoint(&self, checkpoint_path: &PathBuf) -> TorshResult<()> {
346 debug!("Verifying checkpoint {:?}", checkpoint_path);
347
348 let checkpoint = self.load_checkpoint(checkpoint_path).await?;
350 if checkpoint.is_none() {
351 return Err(TorshDistributedError::backend_error(
352 "checkpoint",
353 "Checkpoint verification failed: could not load",
354 ));
355 }
356
357 debug!("Checkpoint verification successful");
358 Ok(())
359 }
360
361 async fn find_checkpoint_files(&self) -> TorshResult<Vec<PathBuf>> {
363 let mut checkpoint_files = Vec::new();
364
365 let mut dir_entries = fs::read_dir(&self.config.checkpoint_dir)
366 .await
367 .map_err(|e| {
368 TorshDistributedError::backend_error(
369 "checkpoint",
370 format!("Failed to read checkpoint directory: {}", e),
371 )
372 })?;
373
374 while let Some(entry) = dir_entries.next_entry().await.map_err(|e| {
375 TorshDistributedError::backend_error(
376 "checkpoint",
377 format!("Failed to read directory entry: {}", e),
378 )
379 })? {
380 let path = entry.path();
381 if let Some(filename) = path.file_name() {
382 if filename.to_string_lossy().starts_with("checkpoint_")
383 && filename.to_string_lossy().ends_with(".json")
384 {
385 checkpoint_files.push(path);
386 }
387 }
388 }
389
390 Ok(checkpoint_files)
391 }
392
393 fn extract_step_from_filename(&self, path: &Path) -> usize {
395 if let Some(filename) = path.file_stem() {
396 let filename_str = filename.to_string_lossy();
397 if let Some(step_start) = filename_str.find("step_") {
398 let step_part = &filename_str[step_start + 5..];
399 if let Some(rank_pos) = step_part.find("_rank_") {
400 let step_str = &step_part[..rank_pos];
401 return step_str.parse().unwrap_or(0);
402 }
403 }
404 }
405 0
406 }
407
408 pub fn get_latest_checkpoint_info(&self) -> Option<TrainingCheckpoint> {
410 self.latest_checkpoint
411 .read()
412 .expect("lock should not be poisoned")
413 .clone()
414 }
415
416 pub async fn cleanup_all_checkpoints(&self) -> TorshResult<()> {
418 let checkpoint_files = self.find_checkpoint_files().await?;
419
420 for file in checkpoint_files {
421 if let Err(e) = fs::remove_file(&file).await {
422 warn!("Failed to remove checkpoint file {:?}: {}", file, e);
423 }
424 }
425
426 {
427 let mut history = self
428 .checkpoint_history
429 .write()
430 .expect("lock should not be poisoned");
431 history.clear();
432 }
433 {
434 let mut latest = self
435 .latest_checkpoint
436 .write()
437 .expect("lock should not be poisoned");
438 *latest = None;
439 }
440
441 info!("Cleaned up all checkpoints");
442 Ok(())
443 }
444}
445
/// Coordinates elastic scaling of a distributed training job: detects
/// worker failures/joins, drives the scaling state machine, and checkpoints
/// before resizing.
#[derive(Debug)]
pub struct ElasticTrainingManager {
    /// Worker-count bounds and scheduling/rendezvous settings.
    config: ElasticConfig,
    /// Current position in the Stable/Scaling/Synchronizing state machine.
    scaling_state: Arc<RwLock<ScalingState>>,
    /// Used to checkpoint state before a resize.
    checkpoint_manager: CheckpointManager,
    /// Current number of workers in the job.
    current_world_size: Arc<RwLock<usize>>,
    /// Last-seen timestamp per worker rank (not read within this file —
    /// presumably fed by heartbeats; confirm against callers).
    worker_registry: Arc<RwLock<HashMap<usize, SystemTime>>>,
    /// Bounded log of recent scaling events (capped at ~100 entries).
    scaling_events: Arc<Mutex<Vec<ScalingEvent>>>,
}
456
457impl ElasticTrainingManager {
458 pub fn new(
460 config: ElasticConfig,
461 checkpoint_config: CheckpointConfig,
462 initial_world_size: usize,
463 ) -> TorshResult<Self> {
464 let checkpoint_manager = CheckpointManager::new(
465 checkpoint_config,
466 Duration::from_secs(30), Duration::from_secs(120), )?;
469
470 Ok(Self {
471 config,
472 scaling_state: Arc::new(RwLock::new(ScalingState::Stable)),
473 checkpoint_manager,
474 current_world_size: Arc::new(RwLock::new(initial_world_size)),
475 worker_registry: Arc::new(RwLock::new(HashMap::new())),
476 scaling_events: Arc::new(Mutex::new(Vec::new())),
477 })
478 }
479
480 pub async fn check_scaling_needs(&self) -> TorshResult<Option<ScalingEvent>> {
482 let current_state = self
483 .scaling_state
484 .read()
485 .expect("lock should not be poisoned")
486 .clone();
487
488 match current_state {
489 ScalingState::Stable => {
490 let current_workers = *self
492 .current_world_size
493 .read()
494 .expect("lock should not be poisoned");
495
496 let failed_workers = self.detect_failed_workers().await?;
498 if !failed_workers.is_empty() {
499 let event = ScalingEvent::WorkerFailure {
500 failed_ranks: failed_workers,
501 };
502 info!("Detected worker failures, initiating scaling: {:?}", event);
503 self.initiate_scaling(event.clone()).await?;
504 return Ok(Some(event));
505 }
506
507 let new_workers = self.detect_new_workers().await?;
509 if !new_workers.is_empty() {
510 let event = ScalingEvent::WorkerJoin {
511 new_ranks: new_workers,
512 };
513 info!("Detected new workers, initiating scaling: {:?}", event);
514 self.initiate_scaling(event.clone()).await?;
515 return Ok(Some(event));
516 }
517
518 if self.config.enable_elastic_scheduling {
520 if let Some(target) = self.calculate_optimal_workers(current_workers).await? {
521 if target != current_workers {
522 let event = ScalingEvent::AutoScale {
523 target_workers: target,
524 reason: "Load-based scaling".to_string(),
525 };
526 info!("Initiating auto-scaling: {:?}", event);
527 self.initiate_scaling(event.clone()).await?;
528 return Ok(Some(event));
529 }
530 }
531 }
532 }
533 ScalingState::Scaling { .. } => {
534 if self.is_scaling_complete().await? {
536 self.complete_scaling().await?;
537 }
538 }
539 ScalingState::Synchronizing { .. } => {
540 if self.is_synchronization_complete().await? {
542 self.complete_synchronization().await?;
543 }
544 }
545 }
546
547 Ok(None)
548 }
549
550 async fn initiate_scaling(&self, event: ScalingEvent) -> TorshResult<()> {
552 info!("Initiating scaling for event: {:?}", event);
553
554 if let Some(checkpoint) = self.checkpoint_manager.get_latest_checkpoint_info() {
556 self.checkpoint_manager
557 .save_checkpoint(checkpoint, 0)
558 .await?;
559 }
560
561 let expected_workers = match &event {
562 ScalingEvent::WorkerFailure { failed_ranks } => {
563 *self
564 .current_world_size
565 .read()
566 .expect("lock should not be poisoned")
567 - failed_ranks.len()
568 }
569 ScalingEvent::WorkerJoin { new_ranks } => {
570 *self
571 .current_world_size
572 .read()
573 .expect("lock should not be poisoned")
574 + new_ranks.len()
575 }
576 ScalingEvent::ManualScale { target_workers }
577 | ScalingEvent::AutoScale { target_workers, .. } => *target_workers,
578 };
579
580 let expected_workers = expected_workers
582 .max(self.config.min_workers)
583 .min(self.config.max_workers);
584
585 {
586 let mut state = self
587 .scaling_state
588 .write()
589 .expect("lock should not be poisoned");
590 *state = ScalingState::Scaling {
591 event: event.clone(),
592 start_time: SystemTime::now(),
593 expected_workers,
594 };
595 }
596
597 {
599 let mut events = self
600 .scaling_events
601 .lock()
602 .expect("lock should not be poisoned");
603 events.push(event);
604 if events.len() > 100 {
606 events.drain(0..50);
607 }
608 }
609
610 Ok(())
611 }
612
613 async fn is_scaling_complete(&self) -> TorshResult<bool> {
615 if let ScalingState::Scaling { start_time, .. } = *self
617 .scaling_state
618 .read()
619 .expect("lock should not be poisoned")
620 {
621 Ok(start_time.elapsed().unwrap_or(Duration::ZERO) >= self.config.scaling_timeout)
622 } else {
623 Ok(false)
624 }
625 }
626
627 async fn complete_scaling(&self) -> TorshResult<()> {
629 info!("Completing scaling process");
630
631 let expected_workers = if let ScalingState::Scaling {
632 expected_workers, ..
633 } = *self
634 .scaling_state
635 .read()
636 .expect("lock should not be poisoned")
637 {
638 expected_workers
639 } else {
640 return Ok(());
641 };
642
643 {
645 let mut state = self
646 .scaling_state
647 .write()
648 .expect("lock should not be poisoned");
649 *state = ScalingState::Synchronizing {
650 current_workers: *self
651 .current_world_size
652 .read()
653 .expect("lock should not be poisoned"),
654 target_workers: expected_workers,
655 };
656 }
657
658 info!("Transitioning to synchronization phase");
659 Ok(())
660 }
661
662 async fn is_synchronization_complete(&self) -> TorshResult<bool> {
664 Ok(true)
666 }
667
668 async fn complete_synchronization(&self) -> TorshResult<()> {
670 info!("Completing synchronization process");
671
672 let target_workers = if let ScalingState::Synchronizing { target_workers, .. } = *self
673 .scaling_state
674 .read()
675 .expect("lock should not be poisoned")
676 {
677 target_workers
678 } else {
679 return Ok(());
680 };
681
682 {
684 let mut world_size = self
685 .current_world_size
686 .write()
687 .expect("lock should not be poisoned");
688 *world_size = target_workers;
689 }
690
691 {
693 let mut state = self
694 .scaling_state
695 .write()
696 .expect("lock should not be poisoned");
697 *state = ScalingState::Stable;
698 }
699
700 info!(
701 "Elastic scaling completed, new world size: {}",
702 target_workers
703 );
704 Ok(())
705 }
706
707 async fn detect_failed_workers(&self) -> TorshResult<Vec<usize>> {
709 Ok(Vec::new())
712 }
713
714 async fn detect_new_workers(&self) -> TorshResult<Vec<usize>> {
716 Ok(Vec::new())
719 }
720
721 async fn calculate_optimal_workers(
723 &self,
724 _current_workers: usize,
725 ) -> TorshResult<Option<usize>> {
726 Ok(None)
735 }
736
737 pub fn get_scaling_state(&self) -> ScalingState {
739 self.scaling_state
740 .read()
741 .expect("lock should not be poisoned")
742 .clone()
743 }
744
745 pub fn get_world_size(&self) -> usize {
747 *self
748 .current_world_size
749 .read()
750 .expect("lock should not be poisoned")
751 }
752
753 pub async fn scale_to(&self, target_workers: usize) -> TorshResult<()> {
755 let event = ScalingEvent::ManualScale { target_workers };
756 self.initiate_scaling(event).await
757 }
758
759 pub fn get_scaling_history(&self) -> Vec<ScalingEvent> {
761 self.scaling_events
762 .lock()
763 .expect("lock should not be poisoned")
764 .clone()
765 }
766
767 pub fn can_proceed_training(&self) -> bool {
769 matches!(
770 *self
771 .scaling_state
772 .read()
773 .expect("lock should not be poisoned"),
774 ScalingState::Stable
775 )
776 }
777
778 pub fn checkpoint_manager(&self) -> &CheckpointManager {
780 &self.checkpoint_manager
781 }
782}
783
784pub mod checkpoint_utils {
786 use super::*;
787
788 #[allow(dead_code)]
790 pub struct CheckpointParams {
791 pub step: usize,
792 pub epoch: usize,
793 pub model_params: HashMap<String, Parameter>,
794 pub optimizer_state: HashMap<String, Tensor>,
795 pub loss: f32,
796 pub metrics: HashMap<String, f32>,
797 pub world_size: usize,
798 pub rank: usize,
799 }
800
801 pub fn create_checkpoint(params: CheckpointParams) -> TorshResult<TrainingCheckpoint> {
803 let CheckpointParams {
804 step,
805 epoch,
806 model_params,
807 optimizer_state,
808 loss,
809 metrics,
810 world_size,
811 rank,
812 } = params;
813 let mut model_state = HashMap::new();
815 for (name, param) in model_params {
816 let tensor = param.tensor();
817 let tensor_guard = tensor.read();
818 let data = tensor_guard.flatten()?.to_vec()?;
819 model_state.insert(name, data);
820 }
821
822 let mut opt_state = HashMap::new();
824 for (name, tensor) in optimizer_state {
825 let data = tensor.flatten()?.to_vec()?;
826 opt_state.insert(name, data);
827 }
828
829 let distributed_meta = DistributedMetadata {
830 world_size,
831 rank,
832 process_group_info: HashMap::new(),
833 dp_size: world_size, tp_size: 1,
835 pp_size: 1,
836 fsdp_sharding: HashMap::new(),
837 };
838
839 Ok(TrainingCheckpoint {
840 step,
841 epoch,
842 model_state,
843 optimizer_state: opt_state,
844 scheduler_state: HashMap::new(),
845 rng_states: HashMap::new(),
846 loss,
847 metrics,
848 config: HashMap::new(),
849 timestamp: SystemTime::now()
850 .duration_since(UNIX_EPOCH)
851 .expect("time should be after UNIX_EPOCH")
852 .as_secs(),
853 version: "1.0.0".to_string(),
854 distributed_meta,
855 })
856 }
857
858 pub fn restore_model_from_checkpoint(
860 checkpoint: &TrainingCheckpoint,
861 ) -> TorshResult<HashMap<String, Tensor>> {
862 let mut model_params = HashMap::new();
863
864 for (name, data) in &checkpoint.model_state {
865 let shape = vec![data.len()]; let tensor = Tensor::from_vec(data.clone(), &shape)?;
867 model_params.insert(name.clone(), tensor);
868 }
869
870 Ok(model_params)
871 }
872
873 pub fn restore_optimizer_from_checkpoint(
875 checkpoint: &TrainingCheckpoint,
876 ) -> TorshResult<HashMap<String, Tensor>> {
877 let mut optimizer_state = HashMap::new();
878
879 for (name, data) in &checkpoint.optimizer_state {
880 let shape = vec![data.len()]; let tensor = Tensor::from_vec(data.clone(), &shape)?;
882 optimizer_state.insert(name.clone(), tensor);
883 }
884
885 Ok(optimizer_state)
886 }
887}
888
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Round-trips a checkpoint through save + load-latest and checks the
    // key fields survive JSON serialization.
    #[tokio::test]
    async fn test_checkpoint_manager() -> TorshResult<()> {
        let temp_dir = TempDir::new().unwrap();
        let config = CheckpointConfig {
            checkpoint_dir: temp_dir.path().to_path_buf(),
            checkpoint_frequency: 100,
            max_checkpoints: 3,
            ..Default::default()
        };

        let manager = CheckpointManager::new(
            config,
            Duration::from_millis(100),
            Duration::from_millis(200),
        )?;

        let checkpoint = TrainingCheckpoint {
            step: 1000,
            epoch: 10,
            model_state: {
                let mut state = HashMap::new();
                state.insert("weight".to_string(), vec![1.0, 2.0, 3.0]);
                state
            },
            optimizer_state: HashMap::new(),
            scheduler_state: HashMap::new(),
            rng_states: HashMap::new(),
            loss: 0.5,
            metrics: HashMap::new(),
            config: HashMap::new(),
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
            version: "1.0.0".to_string(),
            distributed_meta: DistributedMetadata {
                world_size: 4,
                rank: 0,
                process_group_info: HashMap::new(),
                dp_size: 4,
                tp_size: 1,
                pp_size: 1,
                fsdp_sharding: HashMap::new(),
            },
        };

        // Save, then confirm the file actually landed on disk.
        let checkpoint_path = manager.save_checkpoint(checkpoint.clone(), 0).await?;
        assert!(checkpoint_path.exists());

        // Load back and verify the round-trip preserved step and loss.
        let loaded = manager.load_latest_checkpoint().await?;
        assert!(loaded.is_some());
        let loaded_checkpoint = loaded.unwrap();
        assert_eq!(loaded_checkpoint.step, checkpoint.step);
        assert_eq!(loaded_checkpoint.loss, checkpoint.loss);

        Ok(())
    }

    // Exercises a manual scale request and checks the state machine enters
    // Scaling with the requested (in-bounds) worker count.
    #[tokio::test]
    async fn test_elastic_training_manager() -> TorshResult<()> {
        let temp_dir = TempDir::new().unwrap();
        let elastic_config = ElasticConfig {
            min_workers: 2,
            max_workers: 8,
            scaling_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let checkpoint_config = CheckpointConfig {
            checkpoint_dir: temp_dir.path().to_path_buf(),
            ..Default::default()
        };

        let manager = ElasticTrainingManager::new(
            elastic_config,
            checkpoint_config,
            4, // initial world size
        )?;

        assert_eq!(manager.get_world_size(), 4);
        assert!(manager.can_proceed_training());

        // 6 is within [2, 8], so the target should be taken as-is.
        manager.scale_to(6).await?;

        match manager.get_scaling_state() {
            ScalingState::Scaling {
                expected_workers, ..
            } => {
                assert_eq!(expected_workers, 6);
            }
            _ => panic!("Expected scaling state"),
        }

        Ok(())
    }

    // Spot-checks the documented CheckpointConfig defaults.
    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig::default();
        assert_eq!(config.checkpoint_frequency, 1000);
        assert_eq!(config.max_checkpoints, 5);
        assert!(config.async_save);
    }

    // Spot-checks the documented ElasticConfig defaults.
    #[test]
    fn test_elastic_config() {
        let config = ElasticConfig::default();
        assert_eq!(config.min_workers, 1);
        assert_eq!(config.max_workers, 64);
        assert!(config.enable_elastic_scheduling);
    }

    // Distinct ScalingEvent variants must compare unequal (PartialEq).
    #[test]
    fn test_scaling_events() {
        let event1 = ScalingEvent::WorkerFailure {
            failed_ranks: vec![1, 2],
        };
        let event2 = ScalingEvent::WorkerJoin {
            new_ranks: vec![5, 6],
        };
        let event3 = ScalingEvent::ManualScale { target_workers: 8 };

        assert_ne!(event1, event2);
        assert_ne!(event2, event3);
    }
}