use crate::error::{NeuralError, Result};
use crate::serialization::safetensors::{SafeTensorsReader, SafeTensorsWriter};
use crate::serialization::traits::{ModelMetadata, NamedParameters};
use scirs2_core::numeric::{Float, ToPrimitive};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};

/// Errors that can occur while saving or loading checkpoints.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
    #[error("Checkpoint I/O error: {0}")]
    Io(#[from] std::io::Error),

    #[error("Checkpoint serialization error: {0}")]
    Serialization(String),

    #[error("No checkpoint found in directory: {0}")]
    NotFound(String),

    #[error("Checkpoint version mismatch: expected {expected}, found {found}")]
    VersionMismatch { expected: String, found: String },

    #[error("Invalid checkpoint configuration: {0}")]
    InvalidConfig(String),
}

impl From<CheckpointError> for NeuralError {
    fn from(e: CheckpointError) -> Self {
        NeuralError::IOError(e.to_string())
    }
}

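/// Configuration for [`CheckpointManager`]: where to save, how often, how many
/// checkpoints to keep, and which metric decides the "best" checkpoint.
///
/// A minimal construction sketch (compile-ignored; field values are illustrative):
///
/// ```ignore
/// let config = CheckpointConfig {
///     save_dir: "checkpoints/run1".into(),
///     save_every: 5,
///     ..Default::default()
/// };
/// config.validate()?;
/// ```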
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Directory where checkpoints are written.
    pub save_dir: PathBuf,
    /// Save a regular checkpoint every `save_every` epochs.
    pub save_every: usize,
    /// Maximum number of regular checkpoints to keep on disk (0 = unlimited).
    pub max_checkpoints: usize,
    /// Whether to also track and save the best checkpoint as `best.ckpt`.
    pub save_best: bool,
    /// Name of the metric that decides the best checkpoint.
    pub monitor_metric: String,
    /// If true, lower metric values are better (e.g. loss); otherwise higher is better.
    pub minimize_metric: bool,
}

impl Default for CheckpointConfig {
    fn default() -> Self {
        Self {
            save_dir: PathBuf::from("checkpoints"),
            save_every: 1,
            max_checkpoints: 5,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        }
    }
}

impl CheckpointConfig {
    /// Checks that the configuration is internally consistent.
    pub fn validate(&self) -> std::result::Result<(), CheckpointError> {
        if self.monitor_metric.is_empty() {
            return Err(CheckpointError::InvalidConfig(
                "monitor_metric must not be empty".to_string(),
            ));
        }
        Ok(())
    }
}

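/// Serializable snapshot of optimizer hyperparameters and moment buffers.
///
/// A short sketch of the convenience constructors (compile-ignored):
///
/// ```ignore
/// let adam = OptimizerCheckpointState::adam(1e-3, 0.9, 0.999, 1e-8);
/// assert!(!adam.has_moments()); // moment buffers start empty
///
/// let sgd = OptimizerCheckpointState::sgd(0.01, 0.9, 1e-4);
/// assert_eq!(sgd.beta1, Some(0.9)); // SGD momentum is stored in `beta1`
/// ```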
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerCheckpointState {
    /// Optimizer name, e.g. "Adam" or "SGD".
    pub optimizer_type: String,
    /// Current learning rate.
    pub learning_rate: f64,
    /// Number of optimizer steps taken so far.
    pub step: usize,
    /// First-moment decay rate (Adam) or momentum coefficient (SGD).
    pub beta1: Option<f64>,
    /// Second-moment decay rate (Adam).
    pub beta2: Option<f64>,
    /// Numerical-stability constant (Adam).
    pub epsilon: Option<f64>,
    /// Weight-decay (L2 regularization) coefficient.
    pub weight_decay: f64,
    /// Per-parameter first-moment buffers, keyed by parameter name.
    pub first_moments: HashMap<String, Vec<f64>>,
    /// Per-parameter second-moment buffers, keyed by parameter name.
    pub second_moments: HashMap<String, Vec<f64>>,
    /// Shapes of the parameters the moment buffers belong to.
    pub param_shapes: HashMap<String, Vec<usize>>,
}

impl Default for OptimizerCheckpointState {
    fn default() -> Self {
        Self {
            optimizer_type: "Unknown".to_string(),
            learning_rate: 0.001,
            step: 0,
            beta1: None,
            beta2: None,
            epsilon: None,
            weight_decay: 0.0,
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            param_shapes: HashMap::new(),
        }
    }
}

impl OptimizerCheckpointState {
    /// Creates an Adam state with the given hyperparameters and empty moment buffers.
    pub fn adam(learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Self {
            optimizer_type: "Adam".to_string(),
            learning_rate,
            step: 0,
            beta1: Some(beta1),
            beta2: Some(beta2),
            epsilon: Some(epsilon),
            weight_decay: 0.0,
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            param_shapes: HashMap::new(),
        }
    }

    /// Creates an SGD state; the momentum coefficient is stored in `beta1`.
    pub fn sgd(learning_rate: f64, momentum: f64, weight_decay: f64) -> Self {
        Self {
            optimizer_type: "SGD".to_string(),
            learning_rate,
            step: 0,
            beta1: Some(momentum),
            beta2: None,
            epsilon: None,
            weight_decay,
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            param_shapes: HashMap::new(),
        }
    }

    /// Returns true if any moment buffers have been recorded.
    pub fn has_moments(&self) -> bool {
        !self.first_moments.is_empty() || !self.second_moments.is_empty()
    }
}

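/// Serializable snapshot of a learning-rate scheduler.
///
/// Scheduler-specific values live in `extra_params`; a compile-ignored sketch:
///
/// ```ignore
/// let sched = LrSchedulerState::step_lr(0.01, 30, 0.1);
/// assert_eq!(sched.extra_params["step_size"], 30.0);
/// assert_eq!(sched.extra_params["gamma"], 0.1);
/// ```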
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LrSchedulerState {
    /// Scheduler name, e.g. "CosineAnnealing" or "StepLR".
    pub scheduler_type: String,
    /// Number of scheduler steps taken so far.
    pub scheduler_step: usize,
    /// Learning rate after the most recent step.
    pub current_lr: f64,
    /// Initial learning rate.
    pub base_lr: f64,
    /// Epochs without improvement (used by plateau-style schedulers).
    pub patience_counter: usize,
    /// Best metric value seen so far (used by plateau-style schedulers).
    pub best_metric: Option<f64>,
    /// Scheduler-specific parameters such as `t_max`, `step_size`, or `gamma`.
    pub extra_params: HashMap<String, f64>,
}

impl Default for LrSchedulerState {
    fn default() -> Self {
        Self {
            scheduler_type: "Identity".to_string(),
            scheduler_step: 0,
            current_lr: 0.001,
            base_lr: 0.001,
            patience_counter: 0,
            best_metric: None,
            extra_params: HashMap::new(),
        }
    }
}

impl LrSchedulerState {
    /// Creates a cosine-annealing scheduler state with period `t_max`.
    pub fn cosine_annealing(base_lr: f64, t_max: usize) -> Self {
        let mut extra = HashMap::new();
        extra.insert("t_max".to_string(), t_max as f64);
        Self {
            scheduler_type: "CosineAnnealing".to_string(),
            current_lr: base_lr,
            base_lr,
            extra_params: extra,
            ..Default::default()
        }
    }

    /// Creates a step-decay scheduler state that scales the rate by `gamma` every `step_size` steps.
    pub fn step_lr(base_lr: f64, step_size: usize, gamma: f64) -> Self {
        let mut extra = HashMap::new();
        extra.insert("step_size".to_string(), step_size as f64);
        extra.insert("gamma".to_string(), gamma);
        Self {
            scheduler_type: "StepLR".to_string(),
            current_lr: base_lr,
            base_lr,
            extra_params: extra,
            ..Default::default()
        }
    }
}

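/// Full training-state snapshot written alongside the model weights.
///
/// A compile-ignored construction sketch; `epoch` and `global_step` are assumed
/// bindings from the caller's training loop:
///
/// ```ignore
/// let ckpt = TrainingCheckpoint::new(epoch, global_step, "ResNet50");
/// let finished = ckpt.mark_completed();
/// ```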
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Epoch at which the checkpoint was taken.
    pub epoch: usize,
    /// Global training step at which the checkpoint was taken.
    pub step: usize,
    /// Best value of the monitored metric seen so far.
    pub best_metric: Option<f64>,
    /// Per-epoch metric maps, oldest first.
    pub metrics_history: Vec<HashMap<String, f64>>,
    /// Snapshot of the optimizer state.
    pub optimizer_state: OptimizerCheckpointState,
    /// Snapshot of the LR scheduler state, if a scheduler is in use.
    pub lr_scheduler_state: Option<LrSchedulerState>,
    /// Version of the checkpoint format.
    pub format_version: String,
    /// Model architecture name.
    pub architecture: String,
    /// Approximate UTC timestamp of the save.
    pub timestamp: String,
    /// Planned total number of training epochs, if known.
    pub total_epochs: Option<usize>,
    /// Whether training finished normally.
    pub training_completed: bool,
    /// RNG seed for the run, if recorded.
    pub random_seed: Option<u64>,
}

impl Default for TrainingCheckpoint {
    fn default() -> Self {
        Self {
            epoch: 0,
            step: 0,
            best_metric: None,
            metrics_history: Vec::new(),
            optimizer_state: OptimizerCheckpointState::default(),
            lr_scheduler_state: None,
            format_version: "0.3.0".to_string(),
            architecture: "Unknown".to_string(),
            timestamp: simple_timestamp(),
            total_epochs: None,
            training_completed: false,
            random_seed: None,
        }
    }
}

impl TrainingCheckpoint {
    /// Creates a checkpoint snapshot for the given epoch, step, and architecture.
    pub fn new(epoch: usize, step: usize, architecture: &str) -> Self {
        Self {
            epoch,
            step,
            architecture: architecture.to_string(),
            timestamp: simple_timestamp(),
            ..Default::default()
        }
    }

    /// Returns the named metric from the most recent epoch, if present.
    pub fn latest_metric(&self, name: &str) -> Option<f64> {
        self.metrics_history
            .last()
            .and_then(|m| m.get(name).copied())
    }

    /// Marks the run as completed.
    pub fn mark_completed(mut self) -> Self {
        self.training_completed = true;
        self
    }
}

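/// Manages periodic and best-model checkpointing for a training run.
///
/// A compile-ignored training-loop sketch; `params`, `epoch`, `global_step`,
/// and `val_loss` are assumed bindings from the caller's loop:
///
/// ```ignore
/// let mut manager: CheckpointManager<f64> = CheckpointManager::new(CheckpointConfig::default());
///
/// // Inside the training loop:
/// let ckpt = TrainingCheckpoint::new(epoch, global_step, "MyNet");
/// let mut metrics = HashMap::new();
/// metrics.insert("val_loss".to_string(), val_loss);
/// manager.save(&ckpt, &params, epoch, &metrics)?;
///
/// // Later, to resume from the most recent checkpoint:
/// if let Some((ckpt, params)) = CheckpointManager::<f64>::load_latest(Path::new("checkpoints"))? {
///     // restore model weights from `params`, continue at ckpt.epoch + 1
/// }
/// ```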
pub struct CheckpointManager<F: Float + Debug> {
    /// Checkpointing configuration.
    config: CheckpointConfig,
    /// Regular checkpoints saved by this manager, oldest first.
    saved_checkpoints: Vec<PathBuf>,
    /// Best monitored-metric value seen so far.
    best_metric_value: Option<f64>,
    /// Marker for the parameter float type.
    _phantom: PhantomData<F>,
}

impl<F: Float + Debug + ToPrimitive + 'static> CheckpointManager<F> {
    /// Creates a manager with the given configuration.
    pub fn new(config: CheckpointConfig) -> Self {
        Self {
            config,
            saved_checkpoints: Vec::new(),
            best_metric_value: None,
            _phantom: PhantomData,
        }
    }

    /// Saves a checkpoint for `epoch`, honoring `save_every`, `max_checkpoints`,
    /// and best-checkpoint tracking. Returns the checkpoint directory, or the
    /// configured save directory when the epoch is skipped.
    pub fn save(
        &mut self,
        checkpoint: &TrainingCheckpoint,
        model_params: &NamedParameters,
        epoch: usize,
        metrics: &HashMap<String, F>,
    ) -> std::result::Result<PathBuf, CheckpointError> {
        self.config.validate()?;

        // Off-interval epochs still get a chance to update the best checkpoint.
        if self.config.save_every > 0 && !epoch.is_multiple_of(self.config.save_every) {
            if self.config.save_best {
                let _ = self.maybe_save_best(checkpoint, model_params, metrics)?;
            }
            return Ok(self.config.save_dir.clone());
        }

        fs::create_dir_all(&self.config.save_dir)?;

        let dir_name = format!("epoch_{:04}.ckpt", epoch);
        let ckpt_dir = self.config.save_dir.join(&dir_name);

        self.write_checkpoint_to_dir(checkpoint, model_params, &ckpt_dir)?;

        self.saved_checkpoints.push(ckpt_dir.clone());
        self.cleanup_old_checkpoints()?;

        if self.config.save_best {
            let _ = self.maybe_save_best(checkpoint, model_params, metrics)?;
        }

        Ok(ckpt_dir)
    }

    /// Loads the checkpoint stored at `path` (metadata plus model weights).
    pub fn load(
        path: &Path,
    ) -> std::result::Result<(TrainingCheckpoint, NamedParameters), CheckpointError> {
        if !path.exists() {
            return Err(CheckpointError::NotFound(path.display().to_string()));
        }

        let meta_path = path.join("checkpoint_meta.json");
        if !meta_path.exists() {
            return Err(CheckpointError::NotFound(format!(
                "checkpoint_meta.json not found in {}",
                path.display()
            )));
        }
        let meta_bytes = fs::read(&meta_path)?;
        let checkpoint: TrainingCheckpoint = serde_json::from_slice(&meta_bytes)
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;

        let model_path = path.join("model.safetensors");
        if !model_path.exists() {
            return Err(CheckpointError::NotFound(format!(
                "model.safetensors not found in {}",
                path.display()
            )));
        }
        let reader = SafeTensorsReader::from_file(&model_path)
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
        let model_params = reader
            .to_named_parameters()
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;

        Ok((checkpoint, model_params))
    }

    /// Loads the most recent checkpoint in `save_dir`, if any exists.
    pub fn load_latest(
        save_dir: &Path,
    ) -> std::result::Result<Option<(TrainingCheckpoint, NamedParameters)>, CheckpointError> {
        let checkpoints = Self::list_checkpoints(save_dir)?;
        match checkpoints.last() {
            None => Ok(None),
            Some(path) => {
                let result = Self::load(path)?;
                Ok(Some(result))
            }
        }
    }

    /// Loads the best checkpoint (`best.ckpt`) from `save_dir`, if it exists.
    pub fn load_best(
        save_dir: &Path,
    ) -> std::result::Result<Option<(TrainingCheckpoint, NamedParameters)>, CheckpointError> {
        let best_dir = save_dir.join("best.ckpt");
        if !best_dir.exists() {
            return Ok(None);
        }
        let result = Self::load(&best_dir)?;
        Ok(Some(result))
    }

    /// Lists `epoch_*.ckpt` directories in `save_dir`, sorted by epoch.
    pub fn list_checkpoints(save_dir: &Path) -> std::result::Result<Vec<PathBuf>, CheckpointError> {
        if !save_dir.exists() {
            return Ok(Vec::new());
        }

        let entries = fs::read_dir(save_dir)?;
        let mut checkpoints: Vec<(usize, PathBuf)> = Vec::new();

        for entry in entries {
            let entry = entry?;
            let path = entry.path();
            if !path.is_dir() {
                continue;
            }
            let file_name = match path.file_name().and_then(|n| n.to_str()) {
                Some(n) => n.to_owned(),
                None => continue,
            };
            if file_name.starts_with("epoch_") && file_name.ends_with(".ckpt") {
                let epoch_part = file_name
                    .trim_start_matches("epoch_")
                    .trim_end_matches(".ckpt");
                if let Ok(epoch) = epoch_part.parse::<usize>() {
                    checkpoints.push((epoch, path));
                }
            }
        }

        checkpoints.sort_by_key(|(epoch, _)| *epoch);
        Ok(checkpoints.into_iter().map(|(_, p)| p).collect())
    }

    /// Removes the oldest saved checkpoints until at most `max_checkpoints` remain.
    fn cleanup_old_checkpoints(&mut self) -> std::result::Result<(), CheckpointError> {
        if self.config.max_checkpoints == 0 {
            return Ok(());
        }

        while self.saved_checkpoints.len() > self.config.max_checkpoints {
            let oldest = self.saved_checkpoints.remove(0);
            if oldest.exists() {
                fs::remove_dir_all(&oldest)?;
            }
        }
        Ok(())
    }

    /// Saves `checkpoint` as `best.ckpt` if the monitored metric improved.
    /// Returns `Ok(true)` when a new best checkpoint was written.
    fn maybe_save_best(
        &mut self,
        checkpoint: &TrainingCheckpoint,
        model_params: &NamedParameters,
        metrics: &HashMap<String, F>,
    ) -> std::result::Result<bool, CheckpointError> {
        // A missing or non-convertible metric is not an error; just skip.
        let metric_value = match metrics.get(&self.config.monitor_metric) {
            Some(v) => match v.to_f64() {
                Some(f) => f,
                None => return Ok(false),
            },
            None => return Ok(false),
        };

        let is_better = match self.best_metric_value {
            None => true,
            Some(best) => {
                if self.config.minimize_metric {
                    metric_value < best
                } else {
                    metric_value > best
                }
            }
        };

        if is_better {
            self.best_metric_value = Some(metric_value);
            fs::create_dir_all(&self.config.save_dir)?;
            let best_dir = self.config.save_dir.join("best.ckpt");
            self.write_checkpoint_to_dir(checkpoint, model_params, &best_dir)?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Writes checkpoint metadata, model weights, and optimizer state into `dir`.
    fn write_checkpoint_to_dir(
        &self,
        checkpoint: &TrainingCheckpoint,
        model_params: &NamedParameters,
        dir: &Path,
    ) -> std::result::Result<(), CheckpointError> {
        fs::create_dir_all(dir)?;

        let meta_json = serde_json::to_string_pretty(checkpoint)
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
        fs::write(dir.join("checkpoint_meta.json"), meta_json.as_bytes())?;

        let model_path = dir.join("model.safetensors");
        let meta = ModelMetadata::new(
            &checkpoint.architecture,
            "f64",
            model_params.total_parameters(),
        )
        .with_extra("epoch", &checkpoint.epoch.to_string())
        .with_extra("format_version", &checkpoint.format_version);

        let mut writer = SafeTensorsWriter::new();
        writer.add_model_metadata(&meta);
        writer
            .add_named_parameters(model_params)
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
        writer
            .write_to_file(&model_path)
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;

        let opt_json = serde_json::to_string_pretty(&checkpoint.optimizer_state)
            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
        fs::write(dir.join("optimizer_state.json"), opt_json.as_bytes())?;

        Ok(())
    }

    /// Returns the best monitored-metric value seen so far, if any.
    pub fn best_metric_value(&self) -> Option<f64> {
        self.best_metric_value
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &CheckpointConfig {
        &self.config
    }

    /// Returns the paths of checkpoints saved by this manager, oldest first.
    pub fn saved_checkpoint_paths(&self) -> &[PathBuf] {
        &self.saved_checkpoints
    }
}

/// Lightweight, JSON-serializable checkpoint metadata used by the free
/// functions [`save_checkpoint`] and [`load_checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Epoch at which the checkpoint was taken.
    pub epoch: usize,
    /// Global training step.
    pub global_step: usize,
    /// Learning rate at save time.
    pub learning_rate: f64,
    /// Most recent training loss, if recorded.
    pub train_loss: Option<f64>,
    /// Most recent validation loss, if recorded.
    pub val_loss: Option<f64>,
    /// Best validation loss seen so far, if recorded.
    pub best_val_loss: Option<f64>,
    /// Additional named metrics.
    pub metrics: HashMap<String, f64>,
    /// Optimizer hyperparameter summary.
    pub optimizer_state: OptimizerStateMetadata,
    /// Model architecture name.
    pub architecture: String,
    /// Model version string.
    pub model_version: String,
    /// Approximate UTC timestamp of the save.
    pub timestamp: String,
    /// Whether training finished normally.
    pub training_completed: bool,
    /// Planned total number of epochs, if known.
    pub total_epochs: Option<usize>,
    /// RNG seed for the run, if recorded.
    pub random_seed: Option<u64>,
}

impl Default for CheckpointMetadata {
    fn default() -> Self {
        Self {
            epoch: 0,
            global_step: 0,
            learning_rate: 0.001,
            train_loss: None,
            val_loss: None,
            best_val_loss: None,
            metrics: HashMap::new(),
            optimizer_state: OptimizerStateMetadata::default(),
            architecture: "Unknown".to_string(),
            model_version: "0.1.0".to_string(),
            timestamp: simple_timestamp(),
            training_completed: false,
            total_epochs: None,
            random_seed: None,
        }
    }
}

impl CheckpointMetadata {
    /// Creates metadata for the given architecture, epoch, and learning rate.
    pub fn new(architecture: &str, epoch: usize, learning_rate: f64) -> Self {
        Self {
            architecture: architecture.to_string(),
            epoch,
            learning_rate,
            timestamp: simple_timestamp(),
            ..Default::default()
        }
    }

    pub fn with_train_loss(mut self, loss: f64) -> Self {
        self.train_loss = Some(loss);
        self
    }

    pub fn with_val_loss(mut self, loss: f64) -> Self {
        self.val_loss = Some(loss);
        self
    }

    pub fn with_best_val_loss(mut self, loss: f64) -> Self {
        self.best_val_loss = Some(loss);
        self
    }

    pub fn with_metric(mut self, name: &str, value: f64) -> Self {
        self.metrics.insert(name.to_string(), value);
        self
    }

    pub fn with_total_epochs(mut self, total: usize) -> Self {
        self.total_epochs = Some(total);
        self
    }

    pub fn with_global_step(mut self, step: usize) -> Self {
        self.global_step = step;
        self
    }

    /// Marks the run as completed.
    pub fn mark_completed(mut self) -> Self {
        self.training_completed = true;
        self
    }
}

/// Summary of optimizer configuration stored inside [`CheckpointMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerStateMetadata {
    /// Optimizer name, e.g. "Adam" or "SGD".
    pub optimizer_type: String,
    /// Number of parameter groups.
    pub num_param_groups: usize,
    /// Per-group hyperparameters.
    pub param_groups: Vec<ParamGroupState>,
}

impl Default for OptimizerStateMetadata {
    fn default() -> Self {
        Self {
            optimizer_type: "Unknown".to_string(),
            num_param_groups: 0,
            param_groups: Vec::new(),
        }
    }
}

impl OptimizerStateMetadata {
    /// Creates metadata with a single default parameter group.
    pub fn new(optimizer_type: &str) -> Self {
        Self {
            optimizer_type: optimizer_type.to_string(),
            num_param_groups: 1,
            param_groups: vec![ParamGroupState::default()],
        }
    }
}

/// Hyperparameters of a single optimizer parameter group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamGroupState {
    /// Learning rate for this group.
    pub learning_rate: f64,
    /// Weight-decay coefficient.
    pub weight_decay: f64,
    /// Momentum coefficient (SGD).
    pub momentum: Option<f64>,
    /// First-moment decay rate (Adam).
    pub beta1: Option<f64>,
    /// Second-moment decay rate (Adam).
    pub beta2: Option<f64>,
    /// Numerical-stability constant (Adam).
    pub epsilon: Option<f64>,
    /// Number of steps taken for this group.
    pub step_count: usize,
}

impl Default for ParamGroupState {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            weight_decay: 0.0,
            momentum: None,
            beta1: None,
            beta2: None,
            epsilon: None,
            step_count: 0,
        }
    }
}

impl ParamGroupState {
    /// Creates an Adam parameter group.
    pub fn adam(learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Self {
            learning_rate,
            weight_decay: 0.0,
            momentum: None,
            beta1: Some(beta1),
            beta2: Some(beta2),
            epsilon: Some(epsilon),
            step_count: 0,
        }
    }

    /// Creates an SGD parameter group.
    pub fn sgd(learning_rate: f64, momentum: f64, weight_decay: f64) -> Self {
        Self {
            learning_rate,
            weight_decay,
            momentum: Some(momentum),
            beta1: None,
            beta2: None,
            epsilon: None,
            step_count: 0,
        }
    }
}

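/// Saves a complete checkpoint (weights, JSON metadata, and optionally
/// optimizer moments) into `checkpoint_dir`.
///
/// A compile-ignored sketch; `params`, `epoch`, and the losses are assumed
/// bindings built by the caller:
///
/// ```ignore
/// let meta = CheckpointMetadata::new("MyNet", epoch, 1e-3)
///     .with_train_loss(train_loss)
///     .with_val_loss(val_loss);
/// save_checkpoint(&base_dir.join(checkpoint_dir_name(epoch)), &params, &meta, None)?;
/// ```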
pub fn save_checkpoint(
    checkpoint_dir: &Path,
    model_params: &NamedParameters,
    metadata: &CheckpointMetadata,
    optimizer_moments: Option<&NamedParameters>,
) -> Result<()> {
    fs::create_dir_all(checkpoint_dir)
        .map_err(|e| NeuralError::IOError(format!("Cannot create checkpoint directory: {e}")))?;

    // Model weights in SafeTensors format.
    let model_path = checkpoint_dir.join("model.safetensors");
    let model_metadata = ModelMetadata::new(
        &metadata.architecture,
        "f64",
        model_params.total_parameters(),
    )
    .with_extra("epoch", &metadata.epoch.to_string())
    .with_extra("checkpoint", "true");

    let mut writer = SafeTensorsWriter::new();
    writer.add_model_metadata(&model_metadata);
    writer.add_named_parameters(model_params)?;
    writer.write_to_file(&model_path)?;

    // Human-readable checkpoint metadata.
    let meta_path = checkpoint_dir.join("checkpoint_meta.json");
    let meta_json = serde_json::to_string_pretty(metadata)
        .map_err(|e| NeuralError::SerializationError(format!("Cannot serialize metadata: {e}")))?;
    fs::write(&meta_path, meta_json)
        .map_err(|e| NeuralError::IOError(format!("Cannot write metadata: {e}")))?;

    // Optimizer moments are optional and skipped when empty.
    if let Some(moments) = optimizer_moments {
        if !moments.is_empty() {
            let optimizer_path = checkpoint_dir.join("optimizer_state.safetensors");
            let opt_metadata = ModelMetadata::new("optimizer", "f64", moments.total_parameters());
            let mut opt_writer = SafeTensorsWriter::new();
            opt_writer.add_model_metadata(&opt_metadata);
            opt_writer.add_named_parameters(moments)?;
            opt_writer.write_to_file(&optimizer_path)?;
        }
    }

    Ok(())
}

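/// Loads a checkpoint written by [`save_checkpoint`], returning the model
/// parameters, metadata, and optimizer moments (if they were saved).
///
/// A compile-ignored resume sketch; `base_dir` is an assumed caller binding:
///
/// ```ignore
/// if let Some(dir) = latest_checkpoint(base_dir)? {
///     let (params, meta, moments) = load_checkpoint(&dir)?;
///     // `moments` is `Some` only if optimizer_state.safetensors exists
/// }
/// ```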
pub fn load_checkpoint(
    checkpoint_dir: &Path,
) -> Result<(NamedParameters, CheckpointMetadata, Option<NamedParameters>)> {
    if !checkpoint_dir.exists() {
        return Err(NeuralError::IOError(format!(
            "Checkpoint directory does not exist: {}",
            checkpoint_dir.display()
        )));
    }

    // Model weights.
    let model_path = checkpoint_dir.join("model.safetensors");
    if !model_path.exists() {
        return Err(NeuralError::IOError(format!(
            "Model weights not found at: {}",
            model_path.display()
        )));
    }
    let reader = SafeTensorsReader::from_file(&model_path)?;
    let model_params = reader.to_named_parameters()?;

    // Checkpoint metadata.
    let meta_path = checkpoint_dir.join("checkpoint_meta.json");
    if !meta_path.exists() {
        return Err(NeuralError::IOError(format!(
            "Checkpoint metadata not found at: {}",
            meta_path.display()
        )));
    }
    let meta_json = fs::read_to_string(&meta_path)
        .map_err(|e| NeuralError::IOError(format!("Cannot read metadata: {e}")))?;
    let metadata: CheckpointMetadata = serde_json::from_str(&meta_json)
        .map_err(|e| NeuralError::DeserializationError(format!("Invalid metadata: {e}")))?;

    // Optimizer moments are optional.
    let optimizer_path = checkpoint_dir.join("optimizer_state.safetensors");
    let optimizer_moments = if optimizer_path.exists() {
        let opt_reader = SafeTensorsReader::from_file(&optimizer_path)?;
        Some(opt_reader.to_named_parameters()?)
    } else {
        None
    };

    Ok((model_params, metadata, optimizer_moments))
}

/// Lists checkpoint directories under `base_dir` (any subdirectory containing
/// a `checkpoint_meta.json`), sorted by epoch.
pub fn list_checkpoints(base_dir: &Path) -> Result<Vec<(usize, PathBuf)>> {
    if !base_dir.exists() {
        return Ok(Vec::new());
    }

    let mut checkpoints = Vec::new();

    let entries = fs::read_dir(base_dir)
        .map_err(|e| NeuralError::IOError(format!("Cannot read directory: {e}")))?;

    for entry in entries {
        let entry = entry.map_err(|e| NeuralError::IOError(format!("Cannot read entry: {e}")))?;
        let path = entry.path();

        if path.is_dir() {
            let meta_path = path.join("checkpoint_meta.json");
            if meta_path.exists() {
                if let Ok(meta_json) = fs::read_to_string(&meta_path) {
                    if let Ok(meta) = serde_json::from_str::<CheckpointMetadata>(&meta_json) {
                        checkpoints.push((meta.epoch, path));
                    }
                }
            }
        }
    }

    checkpoints.sort_by_key(|(epoch, _)| *epoch);
    Ok(checkpoints)
}

/// Returns the path of the checkpoint with the highest epoch, if any.
pub fn latest_checkpoint(base_dir: &Path) -> Result<Option<PathBuf>> {
    let checkpoints = list_checkpoints(base_dir)?;
    Ok(checkpoints.last().map(|(_, path)| path.clone()))
}

/// Returns the checkpoint with the lowest recorded validation loss, if any.
pub fn best_checkpoint(base_dir: &Path) -> Result<Option<PathBuf>> {
    let checkpoints = list_checkpoints(base_dir)?;

    let mut best: Option<(f64, PathBuf)> = None;

    for (_, path) in &checkpoints {
        let meta_path = path.join("checkpoint_meta.json");
        if let Ok(meta_json) = fs::read_to_string(&meta_path) {
            if let Ok(meta) = serde_json::from_str::<CheckpointMetadata>(&meta_json) {
                if let Some(val_loss) = meta.val_loss {
                    match &best {
                        None => best = Some((val_loss, path.clone())),
                        Some((best_loss, _)) => {
                            if val_loss < *best_loss {
                                best = Some((val_loss, path.clone()));
                            }
                        }
                    }
                }
            }
        }
    }

    Ok(best.map(|(_, path)| path))
}

/// Standard directory name for a checkpoint at the given epoch,
/// e.g. `checkpoint_epoch_0005`.
pub fn checkpoint_dir_name(epoch: usize) -> String {
    format!("checkpoint_epoch_{epoch:04}")
}

/// Builds an approximate ISO-8601 UTC timestamp without external dependencies.
/// Years are treated as 365 days and months as 30 days, so the date can drift;
/// the result is intended only as a human-readable label, not a precise time.
fn simple_timestamp() -> String {
    let now = std::time::SystemTime::now();
    let duration = now
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let secs = duration.as_secs();

    let days = secs / 86400;
    let remaining = secs % 86400;
    let hours = remaining / 3600;
    let minutes = (remaining % 3600) / 60;
    let seconds = remaining % 60;

    // Approximate calendar math: no leap years, fixed 30-day months. The month
    // is clamped so the last few days of a year do not produce a 13th month.
    let years = 1970 + (days / 365);
    let day_in_year = days % 365;
    let month = ((day_in_year / 30) + 1).min(12);
    let day = (day_in_year % 30) + 1;

    format!("{years:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_metadata_default() {
        let meta = CheckpointMetadata::default();
        assert_eq!(meta.epoch, 0);
        assert_eq!(meta.global_step, 0);
        assert!(!meta.training_completed);
        assert!(meta.train_loss.is_none());
        assert!(meta.val_loss.is_none());
    }

    #[test]
    fn test_checkpoint_metadata_builder() {
        let meta = CheckpointMetadata::new("ResNet", 5, 0.001)
            .with_train_loss(0.25)
            .with_val_loss(0.30)
            .with_best_val_loss(0.28)
            .with_metric("accuracy", 0.92)
            .with_total_epochs(100)
            .with_global_step(5000);

        assert_eq!(meta.architecture, "ResNet");
        assert_eq!(meta.epoch, 5);
        assert_eq!(meta.learning_rate, 0.001);
        assert_eq!(meta.train_loss, Some(0.25));
        assert_eq!(meta.val_loss, Some(0.30));
        assert_eq!(meta.best_val_loss, Some(0.28));
        assert_eq!(meta.metrics.get("accuracy"), Some(&0.92));
        assert_eq!(meta.total_epochs, Some(100));
        assert_eq!(meta.global_step, 5000);
    }

    #[test]
    fn test_checkpoint_metadata_serialization() -> Result<()> {
        let meta = CheckpointMetadata::new("BERT", 10, 0.0001)
            .with_train_loss(0.15)
            .with_val_loss(0.20);

        let json = serde_json::to_string_pretty(&meta)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let restored: CheckpointMetadata = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        assert_eq!(restored.architecture, "BERT");
        assert_eq!(restored.epoch, 10);
        assert_eq!(restored.train_loss, Some(0.15));

        Ok(())
    }

    #[test]
    fn test_optimizer_state_metadata() {
        let state = OptimizerStateMetadata::new("Adam");
        assert_eq!(state.optimizer_type, "Adam");
        assert_eq!(state.num_param_groups, 1);
    }

    #[test]
    fn test_param_group_state_adam() {
        let pg = ParamGroupState::adam(0.001, 0.9, 0.999, 1e-8);
        assert_eq!(pg.learning_rate, 0.001);
        assert_eq!(pg.beta1, Some(0.9));
        assert_eq!(pg.beta2, Some(0.999));
        assert_eq!(pg.epsilon, Some(1e-8));
        assert!(pg.momentum.is_none());
    }

    #[test]
    fn test_param_group_state_sgd() {
        let pg = ParamGroupState::sgd(0.01, 0.9, 0.0001);
        assert_eq!(pg.learning_rate, 0.01);
        assert_eq!(pg.momentum, Some(0.9));
        assert_eq!(pg.weight_decay, 0.0001);
        assert!(pg.beta1.is_none());
    }

    #[test]
    fn test_save_load_checkpoint() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_checkpoint_test");
        let checkpoint_dir = test_dir.join("checkpoint_epoch_0005");

        let _ = fs::remove_dir_all(&test_dir);

        let mut params = NamedParameters::new();
        params.add("layer.0.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        params.add("layer.0.bias", vec![0.1, 0.2], vec![2]);
        params.add("layer.1.weight", vec![5.0, 6.0], vec![1, 2]);
        params.add("layer.1.bias", vec![0.3], vec![1]);

        let mut moments = NamedParameters::new();
        moments.add("layer.0.weight.m", vec![0.01, 0.02, 0.03, 0.04], vec![2, 2]);
        moments.add(
            "layer.0.weight.v",
            vec![0.001, 0.002, 0.003, 0.004],
            vec![2, 2],
        );

        let meta = CheckpointMetadata::new("TestModel", 5, 0.001)
            .with_train_loss(0.25)
            .with_val_loss(0.30)
            .with_total_epochs(100);

        save_checkpoint(&checkpoint_dir, &params, &meta, Some(&moments))?;

        assert!(checkpoint_dir.join("model.safetensors").exists());
        assert!(checkpoint_dir.join("checkpoint_meta.json").exists());
        assert!(checkpoint_dir.join("optimizer_state.safetensors").exists());

        let (loaded_params, loaded_meta, loaded_moments) = load_checkpoint(&checkpoint_dir)?;

        assert_eq!(loaded_params.len(), 4);
        assert_eq!(loaded_params.total_parameters(), 9);

        let (_, values, shape) = loaded_params
            .get("layer.0.weight")
            .ok_or_else(|| NeuralError::DeserializationError("not found".to_string()))?;
        assert_eq!(values, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[2, 2]);

        assert_eq!(loaded_meta.architecture, "TestModel");
        assert_eq!(loaded_meta.epoch, 5);
        assert_eq!(loaded_meta.learning_rate, 0.001);
        assert_eq!(loaded_meta.train_loss, Some(0.25));
        assert_eq!(loaded_meta.val_loss, Some(0.30));
        assert_eq!(loaded_meta.total_epochs, Some(100));

        assert!(loaded_moments.is_some());
        let moments = loaded_moments.expect("should have moments");
        assert_eq!(moments.len(), 2);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_save_checkpoint_without_optimizer() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_checkpoint_no_opt");
        let checkpoint_dir = test_dir.join("checkpoint_epoch_0001");

        let _ = fs::remove_dir_all(&test_dir);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0, 2.0], vec![2]);

        let meta = CheckpointMetadata::new("Simple", 1, 0.01);

        save_checkpoint(&checkpoint_dir, &params, &meta, None)?;

        let (loaded_params, loaded_meta, loaded_moments) = load_checkpoint(&checkpoint_dir)?;

        assert_eq!(loaded_params.len(), 1);
        assert_eq!(loaded_meta.epoch, 1);
        assert!(loaded_moments.is_none());

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_list_checkpoints() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_list_checkpoints");
        let _ = fs::remove_dir_all(&test_dir);

        for epoch in [1, 5, 10] {
            let dir_name = checkpoint_dir_name(epoch);
            let dir = test_dir.join(&dir_name);

            let mut params = NamedParameters::new();
            params.add("w", vec![1.0], vec![1]);

            let meta = CheckpointMetadata::new("Test", epoch, 0.001);
            save_checkpoint(&dir, &params, &meta, None)?;
        }

        let checkpoints = list_checkpoints(&test_dir)?;
        assert_eq!(checkpoints.len(), 3);
        assert_eq!(checkpoints[0].0, 1);
        assert_eq!(checkpoints[1].0, 5);
        assert_eq!(checkpoints[2].0, 10);

        let latest = latest_checkpoint(&test_dir)?;
        assert!(latest.is_some());

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_best_checkpoint() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_best_checkpoint");
        let _ = fs::remove_dir_all(&test_dir);

        let losses = [(1, 0.50), (2, 0.35), (3, 0.30), (4, 0.32), (5, 0.28)];

        for (epoch, val_loss) in &losses {
            let dir = test_dir.join(checkpoint_dir_name(*epoch));
            let mut params = NamedParameters::new();
            params.add("w", vec![1.0], vec![1]);

            let meta = CheckpointMetadata::new("Test", *epoch, 0.001).with_val_loss(*val_loss);

            save_checkpoint(&dir, &params, &meta, None)?;
        }

        let best = best_checkpoint(&test_dir)?;
        assert!(best.is_some());

        let (_, meta, _) = load_checkpoint(&best.expect("best should exist"))?;
        assert_eq!(meta.epoch, 5);
        assert_eq!(meta.val_loss, Some(0.28));

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_dir_name() {
        assert_eq!(checkpoint_dir_name(0), "checkpoint_epoch_0000");
        assert_eq!(checkpoint_dir_name(1), "checkpoint_epoch_0001");
        assert_eq!(checkpoint_dir_name(42), "checkpoint_epoch_0042");
        assert_eq!(checkpoint_dir_name(999), "checkpoint_epoch_0999");
        assert_eq!(checkpoint_dir_name(10000), "checkpoint_epoch_10000");
    }

    #[test]
    fn test_load_nonexistent_checkpoint() {
        let result = load_checkpoint(Path::new("/tmp/nonexistent_checkpoint_xyz"));
        assert!(result.is_err());
    }

    #[test]
    fn test_list_empty_directory() -> Result<()> {
        let result = list_checkpoints(Path::new("/tmp/nonexistent_dir_xyz"))?;
        assert!(result.is_empty());
        Ok(())
    }

    #[test]
    fn test_timestamp_format() {
        let ts = simple_timestamp();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
        assert!(ts.len() >= 19);
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert_eq!(config.save_every, 1);
        assert_eq!(config.max_checkpoints, 5);
        assert!(config.save_best);
        assert_eq!(config.monitor_metric, "val_loss");
        assert!(config.minimize_metric);
    }

    #[test]
    fn test_checkpoint_config_validate_ok() {
        let config = CheckpointConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_checkpoint_config_validate_empty_metric() {
        let config = CheckpointConfig {
            monitor_metric: String::new(),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_optimizer_checkpoint_state_adam() {
        let state = OptimizerCheckpointState::adam(0.001, 0.9, 0.999, 1e-8);
        assert_eq!(state.optimizer_type, "Adam");
        assert_eq!(state.learning_rate, 0.001);
        assert_eq!(state.beta1, Some(0.9));
        assert_eq!(state.beta2, Some(0.999));
        assert_eq!(state.epsilon, Some(1e-8));
        assert!(!state.has_moments());
    }

    #[test]
    fn test_optimizer_checkpoint_state_sgd() {
        let state = OptimizerCheckpointState::sgd(0.01, 0.9, 0.0001);
        assert_eq!(state.optimizer_type, "SGD");
        assert_eq!(state.learning_rate, 0.01);
        assert_eq!(state.beta1, Some(0.9)); // SGD momentum is stored in `beta1`
    }

    #[test]
    fn test_lr_scheduler_state_cosine() {
        let state = LrSchedulerState::cosine_annealing(0.001, 100);
        assert_eq!(state.scheduler_type, "CosineAnnealing");
        assert_eq!(state.base_lr, 0.001);
        assert_eq!(state.extra_params["t_max"], 100.0);
    }

    #[test]
    fn test_lr_scheduler_state_step() {
        let state = LrSchedulerState::step_lr(0.01, 30, 0.1);
        assert_eq!(state.scheduler_type, "StepLR");
        assert_eq!(state.extra_params["step_size"], 30.0);
        assert_eq!(state.extra_params["gamma"], 0.1);
    }

    #[test]
    fn test_training_checkpoint_new() {
        let ckpt = TrainingCheckpoint::new(5, 500, "ResNet50");
        assert_eq!(ckpt.epoch, 5);
        assert_eq!(ckpt.step, 500);
        assert_eq!(ckpt.architecture, "ResNet50");
        assert!(!ckpt.training_completed);
        assert!(ckpt.best_metric.is_none());
    }

    #[test]
    fn test_training_checkpoint_latest_metric() {
        let mut ckpt = TrainingCheckpoint::new(3, 300, "BERT");
        let mut metrics = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.35);
        metrics.insert("accuracy".to_string(), 0.88);
        ckpt.metrics_history.push(metrics);

        assert_eq!(ckpt.latest_metric("val_loss"), Some(0.35));
        assert_eq!(ckpt.latest_metric("accuracy"), Some(0.88));
        assert!(ckpt.latest_metric("missing").is_none());
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let config = CheckpointConfig {
            save_dir: std::env::temp_dir().join("test_ckpt_mgr"),
            save_every: 5,
            max_checkpoints: 3,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let manager: CheckpointManager<f64> = CheckpointManager::new(config);
        assert_eq!(manager.config().save_every, 5);
        assert_eq!(manager.config().max_checkpoints, 3);
        assert!(manager.best_metric_value().is_none());
        assert!(manager.saved_checkpoint_paths().is_empty());
    }

    #[test]
    fn test_checkpoint_manager_save_load_roundtrip() -> std::result::Result<(), CheckpointError> {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_test");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 10,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("fc.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        params.add("fc.bias", vec![0.5, 0.5], vec![2]);

        let mut ckpt = TrainingCheckpoint::new(10, 1000, "TestNet");
        ckpt.best_metric = Some(0.32);
        let mut epoch_metrics_map: HashMap<String, f64> = HashMap::new();
        epoch_metrics_map.insert("val_loss".to_string(), 0.32);
        ckpt.metrics_history.push(epoch_metrics_map);

        let mut metrics: HashMap<String, f64> = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.32);

        let saved_path = manager.save(&ckpt, &params, 10, &metrics)?;

        assert!(saved_path.exists() || test_dir.exists());

        let ckpt_path = test_dir.join("epoch_0010.ckpt");
        assert!(
            ckpt_path.exists(),
            "Checkpoint dir should exist: {:?}",
            ckpt_path
        );

        let (loaded_ckpt, loaded_params) = CheckpointManager::<f64>::load(&ckpt_path)?;
        assert_eq!(loaded_ckpt.epoch, 10);
        assert_eq!(loaded_ckpt.step, 1000);
        assert_eq!(loaded_ckpt.architecture, "TestNet");
        assert_eq!(loaded_params.total_parameters(), 6);

        let best = CheckpointManager::<f64>::load_best(&test_dir)?;
        assert!(best.is_some());
        let (best_ckpt, _) = best.expect("best ckpt");
        assert_eq!(best_ckpt.epoch, 10);

        let list = CheckpointManager::<f64>::list_checkpoints(&test_dir)?;
        assert_eq!(list.len(), 1);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_max_checkpoints_cleanup() -> std::result::Result<(), CheckpointError>
    {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_cleanup");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 2, // keep only the two most recent checkpoints
            save_best: false,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0, 2.0], vec![2]);

        let mut metrics: HashMap<String, f64> = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.5);

        for epoch in [0, 1, 2, 3] {
            let ckpt = TrainingCheckpoint::new(epoch, epoch * 100, "TestNet");
            manager.save(&ckpt, &params, epoch, &metrics)?;
        }

        assert_eq!(manager.saved_checkpoint_paths().len(), 2);

        let list = CheckpointManager::<f64>::list_checkpoints(&test_dir)?;
        assert_eq!(list.len(), 2);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_save_best_tracking() -> std::result::Result<(), CheckpointError> {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_best");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 10,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0], vec![1]);

        let val_losses: Vec<f64> = vec![0.9, 0.7, 0.5, 0.6, 0.4, 0.45];

        for (i, &val_loss) in val_losses.iter().enumerate() {
            let ckpt = TrainingCheckpoint::new(i, i * 100, "Net");
            let mut metrics = HashMap::new();
            metrics.insert("val_loss".to_string(), val_loss);
            manager.save(&ckpt, &params, i, &metrics)?;
        }

        assert_eq!(manager.best_metric_value(), Some(0.4));

        let best = CheckpointManager::<f64>::load_best(&test_dir)?;
        assert!(best.is_some());
        let (best_ckpt, _) = best.expect("best ckpt");
        assert_eq!(best_ckpt.epoch, 4);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_load_latest() -> std::result::Result<(), CheckpointError> {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_latest");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 10,
            save_best: false,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0], vec![1]);
        let mut metrics = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.3f64);

        for epoch in 0..5 {
            let ckpt = TrainingCheckpoint::new(epoch, epoch * 50, "Net");
            manager.save(&ckpt, &params, epoch, &metrics)?;
        }

        let latest = CheckpointManager::<f64>::load_latest(&test_dir)?;
        assert!(latest.is_some());
        let (latest_ckpt, _) = latest.expect("latest");
        assert_eq!(latest_ckpt.epoch, 4);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_load_best_no_best() -> std::result::Result<(), CheckpointError> {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_no_best");
        let _ = fs::remove_dir_all(&test_dir);
        let result = CheckpointManager::<f64>::load_best(&test_dir)?;
        assert!(result.is_none());
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_load_latest_empty() -> std::result::Result<(), CheckpointError> {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_empty_latest");
        let _ = fs::remove_dir_all(&test_dir);
        let result = CheckpointManager::<f64>::load_latest(&test_dir)?;
        assert!(result.is_none());
        Ok(())
    }

    #[test]
    fn test_checkpoint_error_display() {
        let err = CheckpointError::NotFound("/tmp/missing".to_string());
        let msg = err.to_string();
        assert!(msg.contains("missing"));

        let err2 = CheckpointError::Serialization("bad json".to_string());
        let msg2 = err2.to_string();
        assert!(msg2.contains("bad json"));
    }

    #[test]
    fn test_training_checkpoint_serialization_roundtrip() {
        let mut ckpt = TrainingCheckpoint::new(7, 700, "GPT");
        ckpt.best_metric = Some(0.28);
        ckpt.total_epochs = Some(50);
        ckpt.optimizer_state = OptimizerCheckpointState::adam(0.001, 0.9, 0.999, 1e-8);
        ckpt.lr_scheduler_state = Some(LrSchedulerState::cosine_annealing(0.001, 50));

        let json = serde_json::to_string_pretty(&ckpt).expect("serialize");
        let restored: TrainingCheckpoint = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(restored.epoch, 7);
        assert_eq!(restored.step, 700);
        assert_eq!(restored.architecture, "GPT");
        assert_eq!(restored.best_metric, Some(0.28));
        assert_eq!(restored.total_epochs, Some(50));
        assert_eq!(restored.optimizer_state.optimizer_type, "Adam");
        assert!(restored.lr_scheduler_state.is_some());
    }
}