Skip to main content

scirs2_neural/training/
checkpoint.rs

1//! Training checkpoint support
2//!
3//! This module provides checkpoint functionality for saving and restoring
4//! mid-training state. A checkpoint captures the model weights, optimizer state,
5//! current epoch, training metrics, and any other state needed to resume training.
6//!
7//! ## Overview
8//!
9//! Checkpoints enable:
10//! - **Fault tolerance**: Resume training after crashes or interruptions
11//! - **Best model tracking**: Save the best model during training
12//! - **Training inspection**: Analyze training state at any point
13//! - **Transfer learning**: Start from a checkpoint with different training config
14//!
15//! ## Format
16//!
17//! Checkpoints are stored as a directory containing:
18//! - `model.safetensors` — Model weights in SafeTensors format
19//! - `checkpoint_meta.json` — Epoch, metrics, and optimizer state metadata
20//! - `optimizer_state.safetensors` — Optional optimizer moment vectors
21//!
22//! ## Usage
23//!
24//! ```rust
25//! use scirs2_neural::training::checkpoint::{CheckpointConfig, CheckpointManager};
26//! use std::path::PathBuf;
27//!
28//! let config = CheckpointConfig {
29//!     save_dir: PathBuf::from("/tmp/checkpoints"),
30//!     save_every: 5,
31//!     max_checkpoints: 3,
32//!     save_best: true,
33//!     monitor_metric: "val_loss".to_string(),
34//!     minimize_metric: true,
35//! };
36//!
37//! let manager = CheckpointManager::<f64>::new(config);
38//! ```
39
40use crate::error::{NeuralError, Result};
41use crate::serialization::safetensors::{SafeTensorsReader, SafeTensorsWriter};
42use crate::serialization::traits::{ModelMetadata, NamedParameters};
43use scirs2_core::ndarray::IxDyn;
44use scirs2_core::numeric::{Float, ToPrimitive};
45use serde::{Deserialize, Serialize};
46use std::collections::HashMap;
47use std::fmt::Debug;
48use std::fs;
49use std::marker::PhantomData;
50use std::path::{Path, PathBuf};
51
52// ============================================================================
53// Error type
54// ============================================================================
55
56/// Error type for checkpoint operations
57#[derive(Debug, thiserror::Error)]
58pub enum CheckpointError {
59    /// I/O error during checkpoint save/load
60    #[error("Checkpoint I/O error: {0}")]
61    Io(#[from] std::io::Error),
62
63    /// Serialization/deserialization error
64    #[error("Checkpoint serialization error: {0}")]
65    Serialization(String),
66
67    /// No checkpoint found in the specified directory
68    #[error("No checkpoint found in directory: {0}")]
69    NotFound(String),
70
71    /// Checkpoint format version mismatch
72    #[error("Checkpoint version mismatch: expected {expected}, found {found}")]
73    VersionMismatch { expected: String, found: String },
74
75    /// Invalid checkpoint configuration
76    #[error("Invalid checkpoint configuration: {0}")]
77    InvalidConfig(String),
78}
79
80impl From<CheckpointError> for NeuralError {
81    fn from(e: CheckpointError) -> Self {
82        NeuralError::IOError(e.to_string())
83    }
84}
85
86// ============================================================================
87// Checkpoint configuration
88// ============================================================================
89
90/// Configuration for the checkpoint manager
91#[derive(Debug, Clone)]
92pub struct CheckpointConfig {
93    /// Directory to save checkpoints in
94    pub save_dir: PathBuf,
95    /// Save a checkpoint every N epochs (0 = disabled)
96    pub save_every: usize,
97    /// Maximum number of checkpoints to keep (0 = keep all)
98    pub max_checkpoints: usize,
99    /// Save the best checkpoint separately as "best.ckpt/"
100    pub save_best: bool,
101    /// Metric to monitor for determining "best" checkpoint
102    /// (e.g., "val_loss", "val_accuracy")
103    pub monitor_metric: String,
104    /// If true, lower values of `monitor_metric` are considered better (e.g., loss)
105    /// If false, higher values are better (e.g., accuracy)
106    pub minimize_metric: bool,
107}
108
109impl Default for CheckpointConfig {
110    fn default() -> Self {
111        Self {
112            save_dir: PathBuf::from("checkpoints"),
113            save_every: 1,
114            max_checkpoints: 5,
115            save_best: true,
116            monitor_metric: "val_loss".to_string(),
117            minimize_metric: true,
118        }
119    }
120}
121
122impl CheckpointConfig {
123    /// Validate the configuration
124    pub fn validate(&self) -> std::result::Result<(), CheckpointError> {
125        if self.monitor_metric.is_empty() {
126            return Err(CheckpointError::InvalidConfig(
127                "monitor_metric must not be empty".to_string(),
128            ));
129        }
130        Ok(())
131    }
132}
133
134// ============================================================================
135// Optimizer checkpoint state
136// ============================================================================
137
138/// Serializable state for first-moment (m) and second-moment (v) Adam buffers,
139/// or SGD momentum buffers.
140///
141/// The moment vectors are stored as raw f64 data so we can serialize them
142/// to JSON. On load they are converted back to the concrete float type F.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct OptimizerCheckpointState {
145    /// Optimizer type name (e.g., "Adam", "SGD", "AdamW")
146    pub optimizer_type: String,
147    /// Current learning rate
148    pub learning_rate: f64,
149    /// Step counter (used to correct Adam bias)
150    pub step: usize,
151    /// Beta1 (Adam-family) or momentum (SGD)
152    pub beta1: Option<f64>,
153    /// Beta2 (Adam-family)
154    pub beta2: Option<f64>,
155    /// Epsilon (Adam-family)
156    pub epsilon: Option<f64>,
157    /// Weight decay
158    pub weight_decay: f64,
159    /// First moment (m) vectors, keyed by parameter name
160    pub first_moments: HashMap<String, Vec<f64>>,
161    /// Second moment (v) vectors, keyed by parameter name
162    pub second_moments: HashMap<String, Vec<f64>>,
163    /// Parameter shapes, keyed by parameter name
164    pub param_shapes: HashMap<String, Vec<usize>>,
165}
166
167impl Default for OptimizerCheckpointState {
168    fn default() -> Self {
169        Self {
170            optimizer_type: "Unknown".to_string(),
171            learning_rate: 0.001,
172            step: 0,
173            beta1: None,
174            beta2: None,
175            epsilon: None,
176            weight_decay: 0.0,
177            first_moments: HashMap::new(),
178            second_moments: HashMap::new(),
179            param_shapes: HashMap::new(),
180        }
181    }
182}
183
184impl OptimizerCheckpointState {
185    /// Create optimizer state for an Adam optimizer
186    pub fn adam(learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
187        Self {
188            optimizer_type: "Adam".to_string(),
189            learning_rate,
190            step: 0,
191            beta1: Some(beta1),
192            beta2: Some(beta2),
193            epsilon: Some(epsilon),
194            weight_decay: 0.0,
195            first_moments: HashMap::new(),
196            second_moments: HashMap::new(),
197            param_shapes: HashMap::new(),
198        }
199    }
200
201    /// Create optimizer state for an SGD optimizer
202    pub fn sgd(learning_rate: f64, momentum: f64, weight_decay: f64) -> Self {
203        Self {
204            optimizer_type: "SGD".to_string(),
205            learning_rate,
206            step: 0,
207            beta1: Some(momentum),
208            beta2: None,
209            epsilon: None,
210            weight_decay,
211            first_moments: HashMap::new(),
212            second_moments: HashMap::new(),
213            param_shapes: HashMap::new(),
214        }
215    }
216
217    /// Returns true if this state has any moment vectors stored
218    pub fn has_moments(&self) -> bool {
219        !self.first_moments.is_empty() || !self.second_moments.is_empty()
220    }
221}
222
223// ============================================================================
224// Learning rate scheduler state
225// ============================================================================
226
227/// Serializable state for learning rate schedulers
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct LrSchedulerState {
230    /// Scheduler type name (e.g., "CosineAnnealing", "StepLR", "ReduceOnPlateau")
231    pub scheduler_type: String,
232    /// Current step within the scheduler
233    pub scheduler_step: usize,
234    /// Current learning rate
235    pub current_lr: f64,
236    /// Base (initial) learning rate
237    pub base_lr: f64,
238    /// Number of epochs without improvement (for ReduceOnPlateau)
239    pub patience_counter: usize,
240    /// Best monitored metric value seen so far
241    pub best_metric: Option<f64>,
242    /// Extra scheduler-specific parameters
243    pub extra_params: HashMap<String, f64>,
244}
245
246impl Default for LrSchedulerState {
247    fn default() -> Self {
248        Self {
249            scheduler_type: "Identity".to_string(),
250            scheduler_step: 0,
251            current_lr: 0.001,
252            base_lr: 0.001,
253            patience_counter: 0,
254            best_metric: None,
255            extra_params: HashMap::new(),
256        }
257    }
258}
259
260impl LrSchedulerState {
261    /// Create a state for a cosine annealing scheduler
262    pub fn cosine_annealing(base_lr: f64, t_max: usize) -> Self {
263        let mut extra = HashMap::new();
264        extra.insert("t_max".to_string(), t_max as f64);
265        Self {
266            scheduler_type: "CosineAnnealing".to_string(),
267            current_lr: base_lr,
268            base_lr,
269            extra_params: extra,
270            ..Default::default()
271        }
272    }
273
274    /// Create a state for a step LR scheduler
275    pub fn step_lr(base_lr: f64, step_size: usize, gamma: f64) -> Self {
276        let mut extra = HashMap::new();
277        extra.insert("step_size".to_string(), step_size as f64);
278        extra.insert("gamma".to_string(), gamma);
279        Self {
280            scheduler_type: "StepLR".to_string(),
281            current_lr: base_lr,
282            base_lr,
283            extra_params: extra,
284            ..Default::default()
285        }
286    }
287}
288
289// ============================================================================
290// Training checkpoint — full training state snapshot
291// ============================================================================
292
293/// A full snapshot of training state, sufficient to resume training identically.
294///
295/// Stored as a directory:
296/// - `checkpoint_meta.json` — all scalar fields and metrics
297/// - `model.safetensors` — model parameter tensors
298/// - `optimizer_state.json` — optimizer moment data
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct TrainingCheckpoint {
301    /// Current epoch number (0-indexed, points to the *completed* epoch)
302    pub epoch: usize,
303    /// Global step counter (total optimizer updates)
304    pub step: usize,
305    /// Best monitored metric value seen across all epochs so far
306    pub best_metric: Option<f64>,
307    /// Metrics history: one `HashMap<String, f64>` per epoch
308    pub metrics_history: Vec<HashMap<String, f64>>,
309    /// Optimizer state (serialized as JSON-compatible struct)
310    pub optimizer_state: OptimizerCheckpointState,
311    /// Learning rate scheduler state (if any)
312    pub lr_scheduler_state: Option<LrSchedulerState>,
313    /// Checkpoint format version for forward compatibility
314    pub format_version: String,
315    /// Architecture name of the saved model
316    pub architecture: String,
317    /// Timestamp when this checkpoint was created
318    pub timestamp: String,
319    /// Total number of epochs planned for training
320    pub total_epochs: Option<usize>,
321    /// Whether training completed without interruption
322    pub training_completed: bool,
323    /// Random seed used for reproducibility
324    pub random_seed: Option<u64>,
325}
326
327impl Default for TrainingCheckpoint {
328    fn default() -> Self {
329        Self {
330            epoch: 0,
331            step: 0,
332            best_metric: None,
333            metrics_history: Vec::new(),
334            optimizer_state: OptimizerCheckpointState::default(),
335            lr_scheduler_state: None,
336            format_version: "0.3.0".to_string(),
337            architecture: "Unknown".to_string(),
338            timestamp: simple_timestamp(),
339            total_epochs: None,
340            training_completed: false,
341            random_seed: None,
342        }
343    }
344}
345
346impl TrainingCheckpoint {
347    /// Create a new checkpoint for the given epoch
348    pub fn new(epoch: usize, step: usize, architecture: &str) -> Self {
349        Self {
350            epoch,
351            step,
352            architecture: architecture.to_string(),
353            timestamp: simple_timestamp(),
354            ..Default::default()
355        }
356    }
357
358    /// Retrieve the latest value of a metric from metrics_history
359    pub fn latest_metric(&self, name: &str) -> Option<f64> {
360        self.metrics_history
361            .last()
362            .and_then(|m| m.get(name).copied())
363    }
364
365    /// Mark training as completed
366    pub fn mark_completed(mut self) -> Self {
367        self.training_completed = true;
368        self
369    }
370}
371
372// ============================================================================
373// Checkpoint manager
374// ============================================================================
375
376/// Manages saving, loading, and cleanup of training checkpoints.
377///
378/// The manager tracks saved checkpoint paths and enforces the `max_checkpoints`
379/// limit by deleting the oldest checkpoint when a new one is saved.
380pub struct CheckpointManager<F: Float + Debug> {
381    /// Configuration
382    config: CheckpointConfig,
383    /// Paths of saved regular (non-best) checkpoints, oldest first
384    saved_checkpoints: Vec<PathBuf>,
385    /// Current best metric value (for best-model tracking)
386    best_metric_value: Option<f64>,
387    /// Phantom for the float type F
388    _phantom: PhantomData<F>,
389}
390
391impl<F: Float + Debug + ToPrimitive + 'static> CheckpointManager<F> {
392    /// Create a new checkpoint manager with the given configuration.
393    pub fn new(config: CheckpointConfig) -> Self {
394        Self {
395            config,
396            saved_checkpoints: Vec::new(),
397            best_metric_value: None,
398            _phantom: PhantomData,
399        }
400    }
401
402    /// Save a training checkpoint.
403    ///
404    /// Creates a directory named `epoch_{epoch:04}.ckpt/` inside `config.save_dir`.
405    /// If `config.save_best` is true and the monitored metric improved, also saves
406    /// a `best.ckpt/` symlink-style copy.
407    ///
408    /// # Arguments
409    ///
410    /// * `checkpoint` — Snapshot of training state (metadata, optimizer, scheduler)
411    /// * `model_params` — Named model parameters to persist in safetensors format
412    /// * `epoch` — Current epoch number (used to name the directory)
413    /// * `metrics` — Current epoch metrics map (e.g., `{"val_loss": 0.35}`)
414    ///
415    /// # Returns
416    ///
417    /// The path to the directory where the checkpoint was written.
418    pub fn save(
419        &mut self,
420        checkpoint: &TrainingCheckpoint,
421        model_params: &NamedParameters,
422        epoch: usize,
423        metrics: &HashMap<String, F>,
424    ) -> std::result::Result<PathBuf, CheckpointError> {
425        self.config.validate()?;
426
427        // Only save if save_every trigger fires
428        if self.config.save_every > 0 && !epoch.is_multiple_of(self.config.save_every) {
429            // Not a checkpoint epoch; still check if best
430            if self.config.save_best {
431                let _ = self.maybe_save_best(checkpoint, model_params, metrics)?;
432            }
433            return Ok(self.config.save_dir.clone());
434        }
435
436        fs::create_dir_all(&self.config.save_dir)?;
437
438        let dir_name = format!("epoch_{:04}.ckpt", epoch);
439        let ckpt_dir = self.config.save_dir.join(&dir_name);
440
441        self.write_checkpoint_to_dir(checkpoint, model_params, &ckpt_dir)?;
442
443        // Track and enforce max_checkpoints
444        self.saved_checkpoints.push(ckpt_dir.clone());
445        self.cleanup_old_checkpoints()?;
446
447        // Check if best
448        if self.config.save_best {
449            let _ = self.maybe_save_best(checkpoint, model_params, metrics)?;
450        }
451
452        Ok(ckpt_dir)
453    }
454
455    /// Load a training checkpoint from a specific directory path.
456    ///
457    /// # Returns
458    ///
459    /// A tuple of `(TrainingCheckpoint, NamedParameters)` where:
460    /// - `TrainingCheckpoint` contains all scalar metadata
461    /// - `NamedParameters` contains the model parameter tensors
462    pub fn load(
463        path: &Path,
464    ) -> std::result::Result<(TrainingCheckpoint, NamedParameters), CheckpointError> {
465        if !path.exists() {
466            return Err(CheckpointError::NotFound(path.display().to_string()));
467        }
468
469        // Load metadata JSON
470        let meta_path = path.join("checkpoint_meta.json");
471        if !meta_path.exists() {
472            return Err(CheckpointError::NotFound(format!(
473                "checkpoint_meta.json not found in {}",
474                path.display()
475            )));
476        }
477        let meta_bytes = fs::read(&meta_path)?;
478        let checkpoint: TrainingCheckpoint = serde_json::from_slice(&meta_bytes)
479            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
480
481        // Load model parameters via safetensors
482        let model_path = path.join("model.safetensors");
483        if !model_path.exists() {
484            return Err(CheckpointError::NotFound(format!(
485                "model.safetensors not found in {}",
486                path.display()
487            )));
488        }
489        let reader = SafeTensorsReader::from_file(&model_path)
490            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
491        let model_params = reader
492            .to_named_parameters()
493            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
494
495        Ok((checkpoint, model_params))
496    }
497
498    /// Load the latest checkpoint from the configured save directory.
499    ///
500    /// Returns `None` if no checkpoints exist.
501    pub fn load_latest(
502        save_dir: &Path,
503    ) -> std::result::Result<Option<(TrainingCheckpoint, NamedParameters)>, CheckpointError> {
504        let checkpoints = Self::list_checkpoints(save_dir)?;
505        match checkpoints.last() {
506            None => Ok(None),
507            Some(path) => {
508                let result = Self::load(path)?;
509                Ok(Some(result))
510            }
511        }
512    }
513
514    /// Load the best checkpoint from the configured save directory.
515    ///
516    /// Looks for a `best.ckpt/` directory in `save_dir`.
517    /// Returns `None` if no best checkpoint exists.
518    pub fn load_best(
519        save_dir: &Path,
520    ) -> std::result::Result<Option<(TrainingCheckpoint, NamedParameters)>, CheckpointError> {
521        let best_dir = save_dir.join("best.ckpt");
522        if !best_dir.exists() {
523            return Ok(None);
524        }
525        let result = Self::load(&best_dir)?;
526        Ok(Some(result))
527    }
528
529    /// List all regular checkpoint directories in a save directory, sorted by epoch.
530    ///
531    /// Only directories matching `epoch_NNNN.ckpt` pattern are included.
532    /// The `best.ckpt` directory is excluded.
533    pub fn list_checkpoints(save_dir: &Path) -> std::result::Result<Vec<PathBuf>, CheckpointError> {
534        if !save_dir.exists() {
535            return Ok(Vec::new());
536        }
537
538        let entries = fs::read_dir(save_dir)?;
539        let mut checkpoints: Vec<(usize, PathBuf)> = Vec::new();
540
541        for entry in entries {
542            let entry = entry?;
543            let path = entry.path();
544            if !path.is_dir() {
545                continue;
546            }
547            let file_name = match path.file_name().and_then(|n| n.to_str()) {
548                Some(n) => n.to_owned(),
549                None => continue,
550            };
551            // Match pattern epoch_NNNN.ckpt
552            if file_name.starts_with("epoch_") && file_name.ends_with(".ckpt") {
553                let epoch_part = file_name
554                    .trim_start_matches("epoch_")
555                    .trim_end_matches(".ckpt");
556                if let Ok(epoch) = epoch_part.parse::<usize>() {
557                    checkpoints.push((epoch, path));
558                }
559            }
560        }
561
562        // Sort by epoch number ascending
563        checkpoints.sort_by_key(|(epoch, _)| *epoch);
564        Ok(checkpoints.into_iter().map(|(_, p)| p).collect())
565    }
566
567    /// Delete old checkpoints, keeping only the `max_checkpoints` most recent.
568    fn cleanup_old_checkpoints(&mut self) -> std::result::Result<(), CheckpointError> {
569        if self.config.max_checkpoints == 0 {
570            return Ok(());
571        }
572
573        while self.saved_checkpoints.len() > self.config.max_checkpoints {
574            let oldest = self.saved_checkpoints.remove(0);
575            if oldest.exists() {
576                fs::remove_dir_all(&oldest)?;
577            }
578        }
579        Ok(())
580    }
581
582    /// Check if the current metrics improve on the best, and if so, save a `best.ckpt/` copy.
583    fn maybe_save_best(
584        &mut self,
585        checkpoint: &TrainingCheckpoint,
586        model_params: &NamedParameters,
587        metrics: &HashMap<String, F>,
588    ) -> std::result::Result<bool, CheckpointError> {
589        let metric_value = match metrics.get(&self.config.monitor_metric) {
590            Some(v) => match v.to_f64() {
591                Some(f) => f,
592                None => return Ok(false),
593            },
594            None => return Ok(false),
595        };
596
597        let is_better = match self.best_metric_value {
598            None => true,
599            Some(best) => {
600                if self.config.minimize_metric {
601                    metric_value < best
602                } else {
603                    metric_value > best
604                }
605            }
606        };
607
608        if is_better {
609            self.best_metric_value = Some(metric_value);
610            fs::create_dir_all(&self.config.save_dir)?;
611            let best_dir = self.config.save_dir.join("best.ckpt");
612            self.write_checkpoint_to_dir(checkpoint, model_params, &best_dir)?;
613            Ok(true)
614        } else {
615            Ok(false)
616        }
617    }
618
619    /// Internal: write checkpoint to a specific directory path.
620    fn write_checkpoint_to_dir(
621        &self,
622        checkpoint: &TrainingCheckpoint,
623        model_params: &NamedParameters,
624        dir: &Path,
625    ) -> std::result::Result<(), CheckpointError> {
626        fs::create_dir_all(dir)?;
627
628        // Write metadata JSON
629        let meta_json = serde_json::to_string_pretty(checkpoint)
630            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
631        fs::write(dir.join("checkpoint_meta.json"), meta_json.as_bytes())?;
632
633        // Write model parameters using SafeTensors
634        let model_path = dir.join("model.safetensors");
635        let meta = ModelMetadata::new(
636            &checkpoint.architecture,
637            "f64",
638            model_params.total_parameters(),
639        )
640        .with_extra("epoch", &checkpoint.epoch.to_string())
641        .with_extra("format_version", &checkpoint.format_version);
642
643        let mut writer = SafeTensorsWriter::new();
644        writer.add_model_metadata(&meta);
645        writer
646            .add_named_parameters(model_params)
647            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
648        writer
649            .write_to_file(&model_path)
650            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
651
652        // Write optimizer state as a separate JSON file for easy inspection
653        let opt_json = serde_json::to_string_pretty(&checkpoint.optimizer_state)
654            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
655        fs::write(dir.join("optimizer_state.json"), opt_json.as_bytes())?;
656
657        Ok(())
658    }
659
660    /// Get the current best metric value tracked by this manager.
661    pub fn best_metric_value(&self) -> Option<f64> {
662        self.best_metric_value
663    }
664
665    /// Get a reference to the configuration.
666    pub fn config(&self) -> &CheckpointConfig {
667        &self.config
668    }
669
670    /// Get the list of currently tracked checkpoint paths.
671    pub fn saved_checkpoint_paths(&self) -> &[PathBuf] {
672        &self.saved_checkpoints
673    }
674}
675
676// ============================================================================
677// Legacy / lower-level checkpoint functions (compatible with old API)
678// ============================================================================
679
680/// A training checkpoint capturing the full state needed to resume training
681///
682/// This is the legacy metadata structure — see [`TrainingCheckpoint`] for
683/// the richer v0.3.0 version.
684#[derive(Debug, Clone, Serialize, Deserialize)]
685pub struct CheckpointMetadata {
686    /// Current epoch number (0-indexed)
687    pub epoch: usize,
688    /// Global step counter
689    pub global_step: usize,
690    /// Current learning rate
691    pub learning_rate: f64,
692    /// Training loss at checkpoint time
693    pub train_loss: Option<f64>,
694    /// Validation loss at checkpoint time
695    pub val_loss: Option<f64>,
696    /// Best validation loss seen so far
697    pub best_val_loss: Option<f64>,
698    /// Additional training metrics (e.g., accuracy, F1)
699    pub metrics: HashMap<String, f64>,
700    /// Optimizer state metadata
701    pub optimizer_state: OptimizerStateMetadata,
702    /// Architecture name
703    pub architecture: String,
704    /// Model version string
705    pub model_version: String,
706    /// Timestamp when checkpoint was created
707    pub timestamp: String,
708    /// Whether training was completed or interrupted
709    pub training_completed: bool,
710    /// Total number of epochs planned
711    pub total_epochs: Option<usize>,
712    /// Random seed used for reproducibility
713    pub random_seed: Option<u64>,
714}
715
716impl Default for CheckpointMetadata {
717    fn default() -> Self {
718        Self {
719            epoch: 0,
720            global_step: 0,
721            learning_rate: 0.001,
722            train_loss: None,
723            val_loss: None,
724            best_val_loss: None,
725            metrics: HashMap::new(),
726            optimizer_state: OptimizerStateMetadata::default(),
727            architecture: "Unknown".to_string(),
728            model_version: "0.1.0".to_string(),
729            timestamp: simple_timestamp(),
730            training_completed: false,
731            total_epochs: None,
732            random_seed: None,
733        }
734    }
735}
736
737impl CheckpointMetadata {
738    /// Create a new checkpoint metadata with basic info
739    pub fn new(architecture: &str, epoch: usize, learning_rate: f64) -> Self {
740        Self {
741            architecture: architecture.to_string(),
742            epoch,
743            learning_rate,
744            timestamp: simple_timestamp(),
745            ..Default::default()
746        }
747    }
748
749    /// Set training loss
750    pub fn with_train_loss(mut self, loss: f64) -> Self {
751        self.train_loss = Some(loss);
752        self
753    }
754
755    /// Set validation loss
756    pub fn with_val_loss(mut self, loss: f64) -> Self {
757        self.val_loss = Some(loss);
758        self
759    }
760
761    /// Set best validation loss
762    pub fn with_best_val_loss(mut self, loss: f64) -> Self {
763        self.best_val_loss = Some(loss);
764        self
765    }
766
767    /// Add a metric
768    pub fn with_metric(mut self, name: &str, value: f64) -> Self {
769        self.metrics.insert(name.to_string(), value);
770        self
771    }
772
773    /// Set total epochs
774    pub fn with_total_epochs(mut self, total: usize) -> Self {
775        self.total_epochs = Some(total);
776        self
777    }
778
779    /// Set global step
780    pub fn with_global_step(mut self, step: usize) -> Self {
781        self.global_step = step;
782        self
783    }
784
785    /// Mark training as completed
786    pub fn mark_completed(mut self) -> Self {
787        self.training_completed = true;
788        self
789    }
790}
791
792/// Metadata for optimizer state
793#[derive(Debug, Clone, Serialize, Deserialize)]
794pub struct OptimizerStateMetadata {
795    /// Optimizer type name (e.g., "Adam", "SGD", "AdamW")
796    pub optimizer_type: String,
797    /// Number of parameter groups
798    pub num_param_groups: usize,
799    /// Per-parameter-group settings
800    pub param_groups: Vec<ParamGroupState>,
801}
802
803impl Default for OptimizerStateMetadata {
804    fn default() -> Self {
805        Self {
806            optimizer_type: "Unknown".to_string(),
807            num_param_groups: 0,
808            param_groups: Vec::new(),
809        }
810    }
811}
812
813impl OptimizerStateMetadata {
814    /// Create metadata for a simple optimizer
815    pub fn new(optimizer_type: &str) -> Self {
816        Self {
817            optimizer_type: optimizer_type.to_string(),
818            num_param_groups: 1,
819            param_groups: vec![ParamGroupState::default()],
820        }
821    }
822}
823
824/// State of a single parameter group in the optimizer
825#[derive(Debug, Clone, Serialize, Deserialize)]
826pub struct ParamGroupState {
827    /// Learning rate for this group
828    pub learning_rate: f64,
829    /// Weight decay
830    pub weight_decay: f64,
831    /// Momentum (for SGD-like optimizers)
832    pub momentum: Option<f64>,
833    /// Beta1 (for Adam-like optimizers)
834    pub beta1: Option<f64>,
835    /// Beta2 (for Adam-like optimizers)
836    pub beta2: Option<f64>,
837    /// Epsilon (for Adam-like optimizers)
838    pub epsilon: Option<f64>,
839    /// Step count for this group
840    pub step_count: usize,
841}
842
843impl Default for ParamGroupState {
844    fn default() -> Self {
845        Self {
846            learning_rate: 0.001,
847            weight_decay: 0.0,
848            momentum: None,
849            beta1: None,
850            beta2: None,
851            epsilon: None,
852            step_count: 0,
853        }
854    }
855}
856
857impl ParamGroupState {
858    /// Create state for an Adam optimizer
859    pub fn adam(learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
860        Self {
861            learning_rate,
862            weight_decay: 0.0,
863            momentum: None,
864            beta1: Some(beta1),
865            beta2: Some(beta2),
866            epsilon: Some(epsilon),
867            step_count: 0,
868        }
869    }
870
871    /// Create state for an SGD optimizer
872    pub fn sgd(learning_rate: f64, momentum: f64, weight_decay: f64) -> Self {
873        Self {
874            learning_rate,
875            weight_decay,
876            momentum: Some(momentum),
877            beta1: None,
878            beta2: None,
879            epsilon: None,
880            step_count: 0,
881        }
882    }
883}
884
885// ============================================================================
886// Lower-level save/load functions (legacy compatibility)
887// ============================================================================
888
889/// Save a training checkpoint to a directory.
890///
891/// Creates a directory containing:
892/// - `model.safetensors` — Model weights
893/// - `checkpoint_meta.json` — Training metadata
894/// - `optimizer_state.safetensors` — Optimizer moment vectors (optional)
895pub fn save_checkpoint(
896    checkpoint_dir: &Path,
897    model_params: &NamedParameters,
898    metadata: &CheckpointMetadata,
899    optimizer_moments: Option<&NamedParameters>,
900) -> Result<()> {
901    fs::create_dir_all(checkpoint_dir)
902        .map_err(|e| NeuralError::IOError(format!("Cannot create checkpoint directory: {e}")))?;
903
904    // Save model weights
905    let model_path = checkpoint_dir.join("model.safetensors");
906    let model_metadata = ModelMetadata::new(
907        &metadata.architecture,
908        "f64",
909        model_params.total_parameters(),
910    )
911    .with_extra("epoch", &metadata.epoch.to_string())
912    .with_extra("checkpoint", "true");
913
914    let mut writer = SafeTensorsWriter::new();
915    writer.add_model_metadata(&model_metadata);
916    writer.add_named_parameters(model_params)?;
917    writer.write_to_file(&model_path)?;
918
919    // Save checkpoint metadata
920    let meta_path = checkpoint_dir.join("checkpoint_meta.json");
921    let meta_json = serde_json::to_string_pretty(metadata)
922        .map_err(|e| NeuralError::SerializationError(format!("Cannot serialize metadata: {e}")))?;
923    fs::write(&meta_path, meta_json)
924        .map_err(|e| NeuralError::IOError(format!("Cannot write metadata: {e}")))?;
925
926    // Save optimizer state if provided
927    if let Some(moments) = optimizer_moments {
928        if !moments.is_empty() {
929            let optimizer_path = checkpoint_dir.join("optimizer_state.safetensors");
930            let opt_metadata = ModelMetadata::new("optimizer", "f64", moments.total_parameters());
931            let mut opt_writer = SafeTensorsWriter::new();
932            opt_writer.add_model_metadata(&opt_metadata);
933            opt_writer.add_named_parameters(moments)?;
934            opt_writer.write_to_file(&optimizer_path)?;
935        }
936    }
937
938    Ok(())
939}
940
941/// Load a training checkpoint from a directory.
942///
943/// Returns the model parameters, checkpoint metadata, and optional optimizer moments.
944pub fn load_checkpoint(
945    checkpoint_dir: &Path,
946) -> Result<(NamedParameters, CheckpointMetadata, Option<NamedParameters>)> {
947    if !checkpoint_dir.exists() {
948        return Err(NeuralError::IOError(format!(
949            "Checkpoint directory does not exist: {}",
950            checkpoint_dir.display()
951        )));
952    }
953
954    // Load model weights
955    let model_path = checkpoint_dir.join("model.safetensors");
956    if !model_path.exists() {
957        return Err(NeuralError::IOError(format!(
958            "Model weights not found at: {}",
959            model_path.display()
960        )));
961    }
962    let reader = SafeTensorsReader::from_file(&model_path)?;
963    let model_params = reader.to_named_parameters()?;
964
965    // Load metadata
966    let meta_path = checkpoint_dir.join("checkpoint_meta.json");
967    if !meta_path.exists() {
968        return Err(NeuralError::IOError(format!(
969            "Checkpoint metadata not found at: {}",
970            meta_path.display()
971        )));
972    }
973    let meta_json = fs::read_to_string(&meta_path)
974        .map_err(|e| NeuralError::IOError(format!("Cannot read metadata: {e}")))?;
975    let metadata: CheckpointMetadata = serde_json::from_str(&meta_json)
976        .map_err(|e| NeuralError::DeserializationError(format!("Invalid metadata: {e}")))?;
977
978    // Load optimizer state if available (safetensors format)
979    let optimizer_path = checkpoint_dir.join("optimizer_state.safetensors");
980    let optimizer_moments = if optimizer_path.exists() {
981        let opt_reader = SafeTensorsReader::from_file(&optimizer_path)?;
982        Some(opt_reader.to_named_parameters()?)
983    } else {
984        None
985    };
986
987    Ok((model_params, metadata, optimizer_moments))
988}
989
990/// List all checkpoints in a directory, sorted by epoch.
991///
992/// Expects checkpoint directories to be named like `checkpoint_epoch_NNNN`.
993pub fn list_checkpoints(base_dir: &Path) -> Result<Vec<(usize, PathBuf)>> {
994    if !base_dir.exists() {
995        return Ok(Vec::new());
996    }
997
998    let mut checkpoints = Vec::new();
999
1000    let entries = fs::read_dir(base_dir)
1001        .map_err(|e| NeuralError::IOError(format!("Cannot read directory: {e}")))?;
1002
1003    for entry in entries {
1004        let entry = entry.map_err(|e| NeuralError::IOError(format!("Cannot read entry: {e}")))?;
1005        let path = entry.path();
1006
1007        if path.is_dir() {
1008            let meta_path = path.join("checkpoint_meta.json");
1009            if meta_path.exists() {
1010                if let Ok(meta_json) = fs::read_to_string(&meta_path) {
1011                    if let Ok(meta) = serde_json::from_str::<CheckpointMetadata>(&meta_json) {
1012                        checkpoints.push((meta.epoch, path));
1013                    }
1014                }
1015            }
1016        }
1017    }
1018
1019    checkpoints.sort_by_key(|(epoch, _)| *epoch);
1020    Ok(checkpoints)
1021}
1022
1023/// Get the latest checkpoint in a directory
1024pub fn latest_checkpoint(base_dir: &Path) -> Result<Option<PathBuf>> {
1025    let checkpoints = list_checkpoints(base_dir)?;
1026    Ok(checkpoints.last().map(|(_, path)| path.clone()))
1027}
1028
1029/// Get the best checkpoint based on validation loss
1030pub fn best_checkpoint(base_dir: &Path) -> Result<Option<PathBuf>> {
1031    let checkpoints = list_checkpoints(base_dir)?;
1032
1033    let mut best: Option<(f64, PathBuf)> = None;
1034
1035    for (_, path) in &checkpoints {
1036        let meta_path = path.join("checkpoint_meta.json");
1037        if let Ok(meta_json) = fs::read_to_string(&meta_path) {
1038            if let Ok(meta) = serde_json::from_str::<CheckpointMetadata>(&meta_json) {
1039                if let Some(val_loss) = meta.val_loss {
1040                    match &best {
1041                        None => best = Some((val_loss, path.clone())),
1042                        Some((best_loss, _)) => {
1043                            if val_loss < *best_loss {
1044                                best = Some((val_loss, path.clone()));
1045                            }
1046                        }
1047                    }
1048                }
1049            }
1050        }
1051    }
1052
1053    Ok(best.map(|(_, path)| path))
1054}
1055
/// Create a checkpoint directory name from an epoch number.
///
/// Epochs below 10000 are zero-padded to four digits (e.g.
/// `checkpoint_epoch_0042`) so lexicographic and numeric ordering agree;
/// larger epochs simply use more digits.
pub fn checkpoint_dir_name(epoch: usize) -> String {
    format!("checkpoint_epoch_{:04}", epoch)
}
1060
1061// ============================================================================
1062// Helper: simple timestamp without chrono dependency
1063// ============================================================================
1064
/// Generate a simple ISO-like UTC timestamp string (`YYYY-MM-DDTHH:MM:SSZ`)
/// using `SystemTime`, without a chrono dependency.
///
/// The date components are calendar-accurate: the previous implementation
/// approximated the date with 365-day years and 30-day months, which could
/// emit an invalid month `13` (for day-in-year >= 360) and drifted by days
/// over the years. The time-of-day arithmetic is unchanged.
fn simple_timestamp() -> String {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    let days = secs / 86400;
    let remaining = secs % 86400;
    let hours = remaining / 3600;
    let minutes = (remaining % 3600) / 60;
    let seconds = remaining % 60;

    let (years, month, day) = civil_from_days(days);

    format!("{years:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
}

/// Convert a day count since 1970-01-01 into a proleptic-Gregorian
/// `(year, month, day)` civil date.
///
/// This is Howard Hinnant's `civil_from_days` algorithm, restricted to
/// non-negative day counts (always true for `SystemTime` past the epoch),
/// so all arithmetic stays in unsigned integers.
fn civil_from_days(days_since_epoch: u64) -> (u64, u64, u64) {
    // Shift the epoch from 1970-01-01 to 0000-03-01 so leap days fall at
    // the end of each 400-year era.
    let z = days_since_epoch + 719_468;
    let era = z / 146_097; // 146097 days per 400-year era
    let doe = z % 146_097; // day-of-era in [0, 146096]
    let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146_096) / 365; // year-of-era in [0, 399]
    let year = yoe + era * 400;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // March-based day-of-year in [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month in [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // day-of-month in [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // civil month in [1, 12]
    // January and February belong to the following civil year.
    let year = if month <= 2 { year + 1 } else { year };
    (year, month, day)
}
1087
1088// ============================================================================
1089// Tests
1090// ============================================================================
1091
#[cfg(test)]
mod tests {
    use super::*;

    // These tests cover both the legacy free functions (save_checkpoint,
    // load_checkpoint, list_checkpoints, ...) and the CheckpointManager API.
    // Filesystem tests each work in their own temp-dir subdirectory and
    // clean it up before and after running.

    #[test]
    fn test_checkpoint_metadata_default() {
        let meta = CheckpointMetadata::default();
        assert_eq!(meta.epoch, 0);
        assert_eq!(meta.global_step, 0);
        assert!(!meta.training_completed);
        assert!(meta.train_loss.is_none());
        assert!(meta.val_loss.is_none());
    }

    #[test]
    fn test_checkpoint_metadata_builder() {
        // Exercise the full builder chain and verify every field lands.
        let meta = CheckpointMetadata::new("ResNet", 5, 0.001)
            .with_train_loss(0.25)
            .with_val_loss(0.30)
            .with_best_val_loss(0.28)
            .with_metric("accuracy", 0.92)
            .with_total_epochs(100)
            .with_global_step(5000);

        assert_eq!(meta.architecture, "ResNet");
        assert_eq!(meta.epoch, 5);
        assert_eq!(meta.learning_rate, 0.001);
        assert_eq!(meta.train_loss, Some(0.25));
        assert_eq!(meta.val_loss, Some(0.30));
        assert_eq!(meta.best_val_loss, Some(0.28));
        assert_eq!(meta.metrics.get("accuracy"), Some(&0.92));
        assert_eq!(meta.total_epochs, Some(100));
        assert_eq!(meta.global_step, 5000);
    }

    #[test]
    fn test_checkpoint_metadata_serialization() -> Result<()> {
        // JSON round-trip: serialize, deserialize, spot-check fields survive.
        let meta = CheckpointMetadata::new("BERT", 10, 0.0001)
            .with_train_loss(0.15)
            .with_val_loss(0.20);

        let json = serde_json::to_string_pretty(&meta)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let restored: CheckpointMetadata = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        assert_eq!(restored.architecture, "BERT");
        assert_eq!(restored.epoch, 10);
        assert_eq!(restored.train_loss, Some(0.15));

        Ok(())
    }

    #[test]
    fn test_optimizer_state_metadata() {
        let state = OptimizerStateMetadata::new("Adam");
        assert_eq!(state.optimizer_type, "Adam");
        assert_eq!(state.num_param_groups, 1);
    }

    #[test]
    fn test_param_group_state_adam() {
        let pg = ParamGroupState::adam(0.001, 0.9, 0.999, 1e-8);
        assert_eq!(pg.learning_rate, 0.001);
        assert_eq!(pg.beta1, Some(0.9));
        assert_eq!(pg.beta2, Some(0.999));
        assert_eq!(pg.epsilon, Some(1e-8));
        assert!(pg.momentum.is_none());
    }

    #[test]
    fn test_param_group_state_sgd() {
        let pg = ParamGroupState::sgd(0.01, 0.9, 0.0001);
        assert_eq!(pg.learning_rate, 0.01);
        assert_eq!(pg.momentum, Some(0.9));
        assert_eq!(pg.weight_decay, 0.0001);
        assert!(pg.beta1.is_none());
    }

    #[test]
    fn test_save_load_checkpoint() -> Result<()> {
        // Full round-trip through the legacy free functions, including
        // optimizer moments.
        let test_dir = std::env::temp_dir().join("scirs2_checkpoint_test");
        let checkpoint_dir = test_dir.join("checkpoint_epoch_0005");

        // Clean up from any previous test runs
        let _ = fs::remove_dir_all(&test_dir);

        // Create model parameters
        let mut params = NamedParameters::new();
        params.add("layer.0.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        params.add("layer.0.bias", vec![0.1, 0.2], vec![2]);
        params.add("layer.1.weight", vec![5.0, 6.0], vec![1, 2]);
        params.add("layer.1.bias", vec![0.3], vec![1]);

        // Create optimizer moments
        let mut moments = NamedParameters::new();
        moments.add("layer.0.weight.m", vec![0.01, 0.02, 0.03, 0.04], vec![2, 2]);
        moments.add(
            "layer.0.weight.v",
            vec![0.001, 0.002, 0.003, 0.004],
            vec![2, 2],
        );

        // Create metadata
        let meta = CheckpointMetadata::new("TestModel", 5, 0.001)
            .with_train_loss(0.25)
            .with_val_loss(0.30)
            .with_total_epochs(100);

        // Save
        save_checkpoint(&checkpoint_dir, &params, &meta, Some(&moments))?;

        // Verify files exist
        assert!(checkpoint_dir.join("model.safetensors").exists());
        assert!(checkpoint_dir.join("checkpoint_meta.json").exists());
        assert!(checkpoint_dir.join("optimizer_state.safetensors").exists());

        // Load
        let (loaded_params, loaded_meta, loaded_moments) = load_checkpoint(&checkpoint_dir)?;

        // Verify model params
        assert_eq!(loaded_params.len(), 4);
        assert_eq!(loaded_params.total_parameters(), 9); // 4+2+2+1

        let (_, values, shape) = loaded_params
            .get("layer.0.weight")
            .ok_or_else(|| NeuralError::DeserializationError("not found".to_string()))?;
        assert_eq!(values, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[2, 2]);

        // Verify metadata
        assert_eq!(loaded_meta.architecture, "TestModel");
        assert_eq!(loaded_meta.epoch, 5);
        assert_eq!(loaded_meta.learning_rate, 0.001);
        assert_eq!(loaded_meta.train_loss, Some(0.25));
        assert_eq!(loaded_meta.val_loss, Some(0.30));
        assert_eq!(loaded_meta.total_epochs, Some(100));

        // Verify optimizer moments
        assert!(loaded_moments.is_some());
        let moments = loaded_moments.expect("should have moments");
        assert_eq!(moments.len(), 2);

        // Clean up
        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_save_checkpoint_without_optimizer() -> Result<()> {
        // When no optimizer moments are passed, loading must yield None.
        let test_dir = std::env::temp_dir().join("scirs2_checkpoint_no_opt");
        let checkpoint_dir = test_dir.join("checkpoint_epoch_0001");

        let _ = fs::remove_dir_all(&test_dir);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0, 2.0], vec![2]);

        let meta = CheckpointMetadata::new("Simple", 1, 0.01);

        save_checkpoint(&checkpoint_dir, &params, &meta, None)?;

        let (loaded_params, loaded_meta, loaded_moments) = load_checkpoint(&checkpoint_dir)?;

        assert_eq!(loaded_params.len(), 1);
        assert_eq!(loaded_meta.epoch, 1);
        assert!(loaded_moments.is_none());

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_list_checkpoints() -> Result<()> {
        // Directories are created out of order (1, 5, 10 by name only);
        // listing must come back sorted by epoch.
        let test_dir = std::env::temp_dir().join("scirs2_list_checkpoints");
        let _ = fs::remove_dir_all(&test_dir);

        for epoch in [1, 5, 10] {
            let dir_name = checkpoint_dir_name(epoch);
            let dir = test_dir.join(&dir_name);

            let mut params = NamedParameters::new();
            params.add("w", vec![1.0], vec![1]);

            let meta = CheckpointMetadata::new("Test", epoch, 0.001);
            save_checkpoint(&dir, &params, &meta, None)?;
        }

        let checkpoints = list_checkpoints(&test_dir)?;
        assert_eq!(checkpoints.len(), 3);
        assert_eq!(checkpoints[0].0, 1);
        assert_eq!(checkpoints[1].0, 5);
        assert_eq!(checkpoints[2].0, 10);

        // Test latest
        let latest = latest_checkpoint(&test_dir)?;
        assert!(latest.is_some());

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_best_checkpoint() -> Result<()> {
        // Losses dip, rise, then dip again — best must be the global
        // minimum (epoch 5, 0.28), not the first local minimum.
        let test_dir = std::env::temp_dir().join("scirs2_best_checkpoint");
        let _ = fs::remove_dir_all(&test_dir);

        let losses = [(1, 0.50), (2, 0.35), (3, 0.30), (4, 0.32), (5, 0.28)];

        for (epoch, val_loss) in &losses {
            let dir = test_dir.join(checkpoint_dir_name(*epoch));
            let mut params = NamedParameters::new();
            params.add("w", vec![1.0], vec![1]);

            let meta = CheckpointMetadata::new("Test", *epoch, 0.001).with_val_loss(*val_loss);

            save_checkpoint(&dir, &params, &meta, None)?;
        }

        let best = best_checkpoint(&test_dir)?;
        assert!(best.is_some());

        // Load the best checkpoint and verify it's epoch 5 (val_loss=0.28)
        let (_, meta, _) = load_checkpoint(&best.expect("best should exist"))?;
        assert_eq!(meta.epoch, 5);
        assert_eq!(meta.val_loss, Some(0.28));

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_dir_name() {
        // Four-digit zero padding below 10000; natural width above.
        assert_eq!(checkpoint_dir_name(0), "checkpoint_epoch_0000");
        assert_eq!(checkpoint_dir_name(1), "checkpoint_epoch_0001");
        assert_eq!(checkpoint_dir_name(42), "checkpoint_epoch_0042");
        assert_eq!(checkpoint_dir_name(999), "checkpoint_epoch_0999");
        assert_eq!(checkpoint_dir_name(10000), "checkpoint_epoch_10000");
    }

    #[test]
    fn test_load_nonexistent_checkpoint() {
        let result = load_checkpoint(Path::new("/tmp/nonexistent_checkpoint_xyz"));
        assert!(result.is_err());
    }

    #[test]
    fn test_list_empty_directory() -> Result<()> {
        // A missing base dir is not an error — it's just an empty listing.
        let result = list_checkpoints(Path::new("/tmp/nonexistent_dir_xyz"))?;
        assert!(result.is_empty());
        Ok(())
    }

    #[test]
    fn test_timestamp_format() {
        // Only the shape is checked ("...T...Z", minimum length), not the
        // actual date, so the test is stable across runs.
        let ts = simple_timestamp();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
        assert!(ts.len() >= 19);
    }

    // -----------------------------------------------------------------------
    // CheckpointConfig and CheckpointManager tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert_eq!(config.save_every, 1);
        assert_eq!(config.max_checkpoints, 5);
        assert!(config.save_best);
        assert_eq!(config.monitor_metric, "val_loss");
        assert!(config.minimize_metric);
    }

    #[test]
    fn test_checkpoint_config_validate_ok() {
        let config = CheckpointConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_checkpoint_config_validate_empty_metric() {
        // An empty monitor_metric must be rejected by validate().
        let config = CheckpointConfig {
            monitor_metric: String::new(),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_optimizer_checkpoint_state_adam() {
        let state = OptimizerCheckpointState::adam(0.001, 0.9, 0.999, 1e-8);
        assert_eq!(state.optimizer_type, "Adam");
        assert_eq!(state.learning_rate, 0.001);
        assert_eq!(state.beta1, Some(0.9));
        assert_eq!(state.beta2, Some(0.999));
        assert_eq!(state.epsilon, Some(1e-8));
        assert!(!state.has_moments());
    }

    #[test]
    fn test_optimizer_checkpoint_state_sgd() {
        let state = OptimizerCheckpointState::sgd(0.01, 0.9, 0.0001);
        assert_eq!(state.optimizer_type, "SGD");
        assert_eq!(state.learning_rate, 0.01);
        assert_eq!(state.beta1, Some(0.9)); // momentum stored in beta1
    }

    #[test]
    fn test_lr_scheduler_state_cosine() {
        let state = LrSchedulerState::cosine_annealing(0.001, 100);
        assert_eq!(state.scheduler_type, "CosineAnnealing");
        assert_eq!(state.base_lr, 0.001);
        assert_eq!(state.extra_params["t_max"], 100.0);
    }

    #[test]
    fn test_lr_scheduler_state_step() {
        let state = LrSchedulerState::step_lr(0.01, 30, 0.1);
        assert_eq!(state.scheduler_type, "StepLR");
        assert_eq!(state.extra_params["step_size"], 30.0);
        assert_eq!(state.extra_params["gamma"], 0.1);
    }

    #[test]
    fn test_training_checkpoint_new() {
        let ckpt = TrainingCheckpoint::new(5, 500, "ResNet50");
        assert_eq!(ckpt.epoch, 5);
        assert_eq!(ckpt.step, 500);
        assert_eq!(ckpt.architecture, "ResNet50");
        assert!(!ckpt.training_completed);
        assert!(ckpt.best_metric.is_none());
    }

    #[test]
    fn test_training_checkpoint_latest_metric() {
        // latest_metric reads from the last entry of metrics_history.
        let mut ckpt = TrainingCheckpoint::new(3, 300, "BERT");
        let mut metrics = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.35);
        metrics.insert("accuracy".to_string(), 0.88);
        ckpt.metrics_history.push(metrics);

        assert_eq!(ckpt.latest_metric("val_loss"), Some(0.35));
        assert_eq!(ckpt.latest_metric("accuracy"), Some(0.88));
        assert!(ckpt.latest_metric("missing").is_none());
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let config = CheckpointConfig {
            save_dir: std::env::temp_dir().join("test_ckpt_mgr"),
            save_every: 5,
            max_checkpoints: 3,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let manager: CheckpointManager<f64> = CheckpointManager::new(config.clone());
        assert_eq!(manager.config().save_every, 5);
        assert_eq!(manager.config().max_checkpoints, 3);
        assert!(manager.best_metric_value().is_none());
        assert!(manager.saved_checkpoint_paths().is_empty());
    }

    #[test]
    fn test_checkpoint_manager_save_load_roundtrip() -> std::result::Result<(), CheckpointError> {
        // End-to-end: save via the manager, then load back via load(),
        // load_best(), and list_checkpoints().
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_test");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 10,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        // Build model params
        let mut params = NamedParameters::new();
        params.add("fc.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        params.add("fc.bias", vec![0.5, 0.5], vec![2]);

        // Build checkpoint
        let mut ckpt = TrainingCheckpoint::new(10, 1000, "TestNet");
        ckpt.best_metric = Some(0.32);
        let mut epoch_metrics_map: HashMap<String, f64> = HashMap::new();
        epoch_metrics_map.insert("val_loss".to_string(), 0.32);
        ckpt.metrics_history.push(epoch_metrics_map);

        // Metrics as F-typed (f64)
        let mut metrics: HashMap<String, f64> = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.32);

        let saved_path = manager.save(&ckpt, &params, 10, &metrics)?;

        // Verify saved directory exists
        assert!(saved_path.exists() || test_dir.exists());

        // Load it back
        let ckpt_path = test_dir.join("epoch_0010.ckpt");
        assert!(
            ckpt_path.exists(),
            "Checkpoint dir should exist: {:?}",
            ckpt_path
        );

        let (loaded_ckpt, loaded_params) = CheckpointManager::<f64>::load(&ckpt_path)?;
        assert_eq!(loaded_ckpt.epoch, 10);
        assert_eq!(loaded_ckpt.step, 1000);
        assert_eq!(loaded_ckpt.architecture, "TestNet");
        assert_eq!(loaded_params.total_parameters(), 6); // 4 + 2

        // Best checkpoint should also be saved
        let best = CheckpointManager::<f64>::load_best(&test_dir)?;
        assert!(best.is_some());
        let (best_ckpt, _) = best.expect("best ckpt");
        assert_eq!(best_ckpt.epoch, 10);

        // list_checkpoints
        let list = CheckpointManager::<f64>::list_checkpoints(&test_dir)?;
        assert_eq!(list.len(), 1);

        // Clean up
        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_max_checkpoints_cleanup() -> std::result::Result<(), CheckpointError>
    {
        // With max_checkpoints = 2, saving 4 checkpoints must leave only
        // the 2 newest on disk.
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_cleanup");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 2, // Keep only 2
            save_best: false,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0, 2.0], vec![2]);

        let mut metrics: HashMap<String, f64> = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.5);

        // Save 4 checkpoints - only 2 should remain
        for epoch in [0, 1, 2, 3] {
            let ckpt = TrainingCheckpoint::new(epoch, epoch * 100, "TestNet");
            manager.save(&ckpt, &params, epoch, &metrics)?;
        }

        // Only max_checkpoints (2) should remain
        assert_eq!(manager.saved_checkpoint_paths().len(), 2);

        // The saved ones should be the newest
        let list = CheckpointManager::<f64>::list_checkpoints(&test_dir)?;
        assert_eq!(list.len(), 2);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_save_best_tracking() -> std::result::Result<(), CheckpointError> {
        // Losses: 0.9, 0.7, 0.5, 0.6, 0.4, 0.45 — best tracking must latch
        // onto 0.4 (epoch 4), not the final value.
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_best");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 10,
            save_best: true,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0], vec![1]);

        let val_losses: Vec<f64> = vec![0.9, 0.7, 0.5, 0.6, 0.4, 0.45];

        for (i, &val_loss) in val_losses.iter().enumerate() {
            let ckpt = TrainingCheckpoint::new(i, i * 100, "Net");
            let mut metrics = HashMap::new();
            metrics.insert("val_loss".to_string(), val_loss);
            manager.save(&ckpt, &params, i, &metrics)?;
        }

        // Best should be 0.4 (epoch 4)
        assert_eq!(manager.best_metric_value(), Some(0.4));

        let best = CheckpointManager::<f64>::load_best(&test_dir)?;
        assert!(best.is_some());
        let (best_ckpt, _) = best.expect("best ckpt");
        assert_eq!(best_ckpt.epoch, 4);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_load_latest() -> std::result::Result<(), CheckpointError> {
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_mgr_latest");
        let _ = fs::remove_dir_all(&test_dir);

        let config = CheckpointConfig {
            save_dir: test_dir.clone(),
            save_every: 1,
            max_checkpoints: 10,
            save_best: false,
            monitor_metric: "val_loss".to_string(),
            minimize_metric: true,
        };
        let mut manager: CheckpointManager<f64> = CheckpointManager::new(config);

        let mut params = NamedParameters::new();
        params.add("w", vec![1.0], vec![1]);
        let mut metrics = HashMap::new();
        metrics.insert("val_loss".to_string(), 0.3f64);

        for epoch in 0..5 {
            let ckpt = TrainingCheckpoint::new(epoch, epoch * 50, "Net");
            manager.save(&ckpt, &params, epoch, &metrics)?;
        }

        let latest = CheckpointManager::<f64>::load_latest(&test_dir)?;
        assert!(latest.is_some());
        let (latest_ckpt, _) = latest.expect("latest");
        assert_eq!(latest_ckpt.epoch, 4);

        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_load_best_no_best() -> std::result::Result<(), CheckpointError> {
        // load_best on an empty/missing directory is Ok(None), not an error.
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_no_best");
        let _ = fs::remove_dir_all(&test_dir);
        let result = CheckpointManager::<f64>::load_best(&test_dir)?;
        assert!(result.is_none());
        Ok(())
    }

    #[test]
    fn test_checkpoint_manager_load_latest_empty() -> std::result::Result<(), CheckpointError> {
        // Same contract for load_latest: empty directory yields Ok(None).
        let test_dir = std::env::temp_dir().join("scirs2_ckpt_empty_latest");
        let _ = fs::remove_dir_all(&test_dir);
        let result = CheckpointManager::<f64>::load_latest(&test_dir)?;
        assert!(result.is_none());
        Ok(())
    }

    #[test]
    fn test_checkpoint_error_display() {
        // Display impls must include the payload string.
        let err = CheckpointError::NotFound("/tmp/missing".to_string());
        let msg = err.to_string();
        assert!(msg.contains("missing"));

        let err2 = CheckpointError::Serialization("bad json".to_string());
        let msg2 = err2.to_string();
        assert!(msg2.contains("bad json"));
    }

    #[test]
    fn test_training_checkpoint_serialization_roundtrip() {
        // Serialize a fully-populated TrainingCheckpoint (including nested
        // optimizer and LR-scheduler state) and verify it round-trips.
        let mut ckpt = TrainingCheckpoint::new(7, 700, "GPT");
        ckpt.best_metric = Some(0.28);
        ckpt.total_epochs = Some(50);
        ckpt.optimizer_state = OptimizerCheckpointState::adam(0.001, 0.9, 0.999, 1e-8);
        ckpt.lr_scheduler_state = Some(LrSchedulerState::cosine_annealing(0.001, 50));

        let json = serde_json::to_string_pretty(&ckpt).expect("serialize");
        let restored: TrainingCheckpoint = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(restored.epoch, 7);
        assert_eq!(restored.step, 700);
        assert_eq!(restored.architecture, "GPT");
        assert_eq!(restored.best_metric, Some(0.28));
        assert_eq!(restored.total_epochs, Some(50));
        assert_eq!(restored.optimizer_state.optimizer_type, "Adam");
        assert!(restored.lr_scheduler_state.is_some());
    }
}