tensorlogic_train/callbacks/checkpoint.rs

//! Checkpoint callbacks for saving and loading training state.
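//!
//! A minimal save/restore sketch (the path below is illustrative, assuming an
//! existing [`TrainingCheckpoint`]):
//!
//! ```no_run
//! use tensorlogic_train::{CheckpointCompression, TrainingCheckpoint};
//! use std::path::PathBuf;
//!
//! # let checkpoint: TrainingCheckpoint = unimplemented!();
//! let path = PathBuf::from("/tmp/checkpoint.json.gz");
//! // Write a gzip-compressed checkpoint, then restore it (compression is
//! // auto-detected from the `.gz` extension on load).
//! checkpoint.save_with_compression(&path, CheckpointCompression::Gzip).unwrap();
//! let restored = TrainingCheckpoint::load(&path).unwrap();
//! assert_eq!(restored.epoch, checkpoint.epoch);
//! ```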

use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Write};
use std::path::PathBuf;

/// Compression method for checkpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CheckpointCompression {
    /// No compression (plain JSON).
    #[default]
    None,
    /// Gzip compression (good balance of speed and ratio).
    Gzip,
    /// Fast gzip compression (faster but lower ratio).
    GzipFast,
    /// Best gzip compression (slower but better ratio).
    GzipBest,
}

/// Comprehensive checkpoint data structure.
///
/// This structure contains all the information needed to fully restore
/// training state, including model parameters, optimizer state, and training history.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingCheckpoint {
    /// Current epoch number.
    pub epoch: usize,
    /// Model parameters as flattened vectors.
    pub parameters: HashMap<String, Vec<f64>>,
    /// Optimizer state as flattened vectors.
    pub optimizer_state: HashMap<String, Vec<f64>>,
    /// Scheduler state (if present).
    pub scheduler_state: Option<HashMap<String, f64>>,
    /// Current training loss.
    pub train_loss: f64,
    /// Current validation loss (if available).
    pub val_loss: Option<f64>,
    /// Training loss history.
    pub train_loss_history: Vec<f64>,
    /// Validation loss history.
    pub val_loss_history: Vec<f64>,
    /// Metrics history.
    pub metrics_history: HashMap<String, Vec<f64>>,
    /// Current learning rate.
    pub learning_rate: f64,
    /// Best validation loss seen so far.
    pub best_val_loss: Option<f64>,
}

impl TrainingCheckpoint {
    /// Create a new checkpoint from current training state.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        epoch: usize,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
        optimizer_state: &HashMap<String, Vec<f64>>,
        scheduler_state: Option<HashMap<String, f64>>,
        state: &TrainingState,
        train_loss_history: &[f64],
        val_loss_history: &[f64],
        metrics_history: &HashMap<String, Vec<f64>>,
        best_val_loss: Option<f64>,
    ) -> Self {
        // Convert parameters to flat vectors
        let parameters = parameters
            .iter()
            .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
            .collect();

        Self {
            epoch,
            parameters,
            optimizer_state: optimizer_state.clone(),
            scheduler_state,
            train_loss: state.train_loss,
            val_loss: state.val_loss,
            train_loss_history: train_loss_history.to_vec(),
            val_loss_history: val_loss_history.to_vec(),
            metrics_history: metrics_history.clone(),
            learning_rate: state.learning_rate,
            best_val_loss,
        }
    }

    /// Save checkpoint to a file.
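    ///
    /// Writes uncompressed, pretty-printed JSON; see [`Self::save_with_compression`]
    /// for gzip output. A minimal sketch (the path is illustrative):
    ///
    /// ```no_run
    /// use tensorlogic_train::TrainingCheckpoint;
    /// use std::path::PathBuf;
    ///
    /// # let checkpoint: TrainingCheckpoint = unimplemented!();
    /// checkpoint.save(&PathBuf::from("/tmp/checkpoint.json")).unwrap();
    /// ```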
    pub fn save(&self, path: &PathBuf) -> TrainResult<()> {
        self.save_with_compression(path, CheckpointCompression::None)
    }

    /// Save checkpoint to a file with compression.
    ///
    /// # Arguments
    /// * `path` - Path to save the checkpoint
    /// * `compression` - Compression method to use
    ///
    /// # Example
    /// ```no_run
    /// use tensorlogic_train::TrainingCheckpoint;
    /// use tensorlogic_train::CheckpointCompression;
    /// use std::path::PathBuf;
    ///
    /// // Assuming you have a checkpoint...
    /// # let checkpoint: TrainingCheckpoint = unimplemented!();
    ///
    /// // Save with gzip compression
    /// checkpoint.save_with_compression(
    ///     &PathBuf::from("/tmp/checkpoint.json.gz"),
    ///     CheckpointCompression::Gzip
    /// ).unwrap();
    /// ```
    pub fn save_with_compression(
        &self,
        path: &PathBuf,
        compression: CheckpointCompression,
    ) -> TrainResult<()> {
        let json = serde_json::to_string_pretty(self).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
        })?;

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
            })?;
        }

        // Plain JSON is written directly; the gzip variants share one code path
        // and differ only in compression level.
        let level = match compression {
            CheckpointCompression::None => {
                std::fs::write(path, json).map_err(|e| {
                    TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
                })?;
                return Ok(());
            }
            CheckpointCompression::Gzip => Compression::default(),
            CheckpointCompression::GzipFast => Compression::fast(),
            CheckpointCompression::GzipBest => Compression::best(),
        };

        let file = File::create(path).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to create checkpoint file: {}", e))
        })?;
        let mut encoder = GzEncoder::new(file, level);
        encoder.write_all(json.as_bytes()).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to compress checkpoint: {}", e))
        })?;
        encoder.finish().map_err(|e| {
            TrainError::CheckpointError(format!("Failed to finish compression: {}", e))
        })?;

        Ok(())
    }

    /// Load checkpoint from a file.
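    ///
    /// Compression is auto-detected from a `.gz` file extension, so plain and
    /// gzip checkpoints load through the same call. A minimal sketch (the path
    /// is illustrative):
    ///
    /// ```no_run
    /// use tensorlogic_train::TrainingCheckpoint;
    /// use std::path::PathBuf;
    ///
    /// let checkpoint = TrainingCheckpoint::load(&PathBuf::from("/tmp/checkpoint.json.gz")).unwrap();
    /// println!("Resuming from epoch {}", checkpoint.epoch);
    /// ```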
    pub fn load(path: &PathBuf) -> TrainResult<Self> {
        // Auto-detect compression based on file extension
        if path.to_string_lossy().ends_with(".gz") {
            Self::load_compressed(path)
        } else {
            Self::load_uncompressed(path)
        }
    }

    /// Load uncompressed checkpoint from a file.
    fn load_uncompressed(path: &PathBuf) -> TrainResult<Self> {
        let json = std::fs::read_to_string(path).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to read checkpoint: {}", e))
        })?;

        let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
        })?;

        Ok(checkpoint)
    }

    /// Load compressed checkpoint from a file.
    pub fn load_compressed(path: &PathBuf) -> TrainResult<Self> {
        let file = File::open(path).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to open checkpoint file: {}", e))
        })?;

        let mut decoder = GzDecoder::new(file);
        let mut json = String::new();
        decoder.read_to_string(&mut json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to decompress checkpoint: {}", e))
        })?;

        let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
        })?;

        Ok(checkpoint)
    }

    /// Get the size of the checkpoint in bytes (estimated).
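    ///
    /// The estimate covers the raw `f64` payload of parameters, optimizer state,
    /// and the loss histories; the size of the serialized JSON on disk will
    /// generally differ. A sketch:
    ///
    /// ```no_run
    /// # use tensorlogic_train::TrainingCheckpoint;
    /// # let checkpoint: TrainingCheckpoint = unimplemented!();
    /// println!("~{} bytes of numeric state", checkpoint.estimated_size());
    /// ```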
    pub fn estimated_size(&self) -> usize {
        // Rough estimate: parameters + optimizer_state + histories
        let param_size: usize = self
            .parameters
            .values()
            .map(|v| v.len() * std::mem::size_of::<f64>())
            .sum();
        let optimizer_size: usize = self
            .optimizer_state
            .values()
            .map(|v| v.len() * std::mem::size_of::<f64>())
            .sum();
        let history_size = (self.train_loss_history.len() + self.val_loss_history.len())
            * std::mem::size_of::<f64>();

        param_size + optimizer_size + history_size
    }
}

/// Checkpoint metadata for tracking saved checkpoints.
#[derive(Debug, Clone, PartialEq)]
struct CheckpointMetadata {
    /// Epoch number.
    epoch: usize,
    /// Validation loss (if available).
    val_loss: Option<f64>,
    /// File path.
    path: PathBuf,
}

/// Callback for model checkpointing with auto-cleanup.
pub struct CheckpointCallback {
    /// Directory to save checkpoints.
    pub checkpoint_dir: PathBuf,
    /// Frequency of checkpointing (every N epochs).
    pub save_frequency: usize,
    /// Whether to save only the best model.
    pub save_best_only: bool,
    /// Maximum number of checkpoints to keep (None = keep all).
    pub keep_top_k: Option<usize>,
    /// Best validation loss seen so far.
    best_val_loss: Option<f64>,
    /// Metadata of saved checkpoints for cleanup.
    saved_checkpoints: Vec<CheckpointMetadata>,
}

impl CheckpointCallback {
    /// Create a new checkpoint callback.
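    ///
    /// A minimal sketch (the directory is illustrative): save every 5 epochs,
    /// regardless of whether validation loss improved, and keep every file.
    ///
    /// ```no_run
    /// use tensorlogic_train::CheckpointCallback;
    /// use std::path::PathBuf;
    ///
    /// let callback = CheckpointCallback::new(
    ///     PathBuf::from("/tmp/checkpoints"),
    ///     5,     // save every 5 epochs
    ///     false, // save on every qualifying epoch, not only on improvement
    /// );
    /// ```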
    pub fn new(checkpoint_dir: PathBuf, save_frequency: usize, save_best_only: bool) -> Self {
        Self {
            checkpoint_dir,
            save_frequency,
            save_best_only,
            keep_top_k: None,
            best_val_loss: None,
            saved_checkpoints: Vec::new(),
        }
    }

    /// Create a new checkpoint callback with auto-cleanup.
    ///
    /// This will automatically delete old checkpoints when the number exceeds `keep_top_k`,
    /// keeping only the checkpoints with the best (lowest) validation loss.
    ///
    /// # Arguments
    /// * `checkpoint_dir` - Directory to save checkpoints
    /// * `save_frequency` - Save every N epochs
    /// * `save_best_only` - Only save when validation loss improves
    /// * `keep_top_k` - Maximum number of checkpoints to keep (keeps best by validation loss)
    ///
    /// # Example
    /// ```no_run
    /// use tensorlogic_train::CheckpointCallback;
    /// use std::path::PathBuf;
    ///
    /// // Keep only the top 5 best checkpoints
    /// let callback = CheckpointCallback::with_cleanup(
    ///     PathBuf::from("/tmp/checkpoints"),
    ///     1,    // save every epoch
    ///     false, // save all, not just best
    ///     5     // keep top 5
    /// );
    /// ```
    pub fn with_cleanup(
        checkpoint_dir: PathBuf,
        save_frequency: usize,
        save_best_only: bool,
        keep_top_k: usize,
    ) -> Self {
        Self {
            checkpoint_dir,
            save_frequency,
            save_best_only,
            keep_top_k: Some(keep_top_k),
            best_val_loss: None,
            saved_checkpoints: Vec::new(),
        }
    }

    /// Get the number of saved checkpoints being tracked.
    pub fn num_saved_checkpoints(&self) -> usize {
        self.saved_checkpoints.len()
    }

    /// Manually clean up checkpoints, keeping only the top-k best.
    ///
    /// This can be called manually to trigger cleanup if you've changed the
    /// `keep_top_k` setting.
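    ///
    /// Returns the number of checkpoint files that were deleted. A minimal sketch
    /// (the directory is illustrative; assumes the callback was built with
    /// [`Self::with_cleanup`] and has already saved some checkpoints):
    ///
    /// ```no_run
    /// use tensorlogic_train::CheckpointCallback;
    /// use std::path::PathBuf;
    ///
    /// let mut callback = CheckpointCallback::with_cleanup(
    ///     PathBuf::from("/tmp/checkpoints"), 1, false, 3,
    /// );
    /// // ... training saves checkpoints through the callback ...
    /// let deleted = callback.cleanup_checkpoints().unwrap();
    /// println!("Removed {} old checkpoints", deleted);
    /// ```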
    pub fn cleanup_checkpoints(&mut self) -> TrainResult<usize> {
        let keep_top_k = match self.keep_top_k {
            Some(k) => k,
            None => return Ok(0), // No cleanup needed
        };

        if self.saved_checkpoints.len() <= keep_top_k {
            return Ok(0); // Don't need to clean up yet
        }

        // Sort by validation loss (ascending - best first)
        // For checkpoints without val_loss, prefer more recent epochs
        self.saved_checkpoints.sort_by(|a, b| {
            match (a.val_loss, b.val_loss) {
                (Some(a_loss), Some(b_loss)) => a_loss
                    .partial_cmp(&b_loss)
                    .unwrap_or(std::cmp::Ordering::Equal),
                (Some(_), None) => std::cmp::Ordering::Less, // Prefer checkpoints with val_loss
                (None, Some(_)) => std::cmp::Ordering::Greater, // Prefer checkpoints with val_loss
                (None, None) => b.epoch.cmp(&a.epoch),       // Prefer newer epochs (descending)
            }
        });

        // Remove checkpoints beyond top-k
        let to_remove: Vec<CheckpointMetadata> =
            self.saved_checkpoints.drain(keep_top_k..).collect();

        let mut deleted_count = 0;
        for checkpoint in to_remove {
            if let Err(e) = std::fs::remove_file(&checkpoint.path) {
                eprintln!(
                    "Warning: Failed to delete checkpoint {:?}: {}",
                    checkpoint.path, e
                );
            } else {
                deleted_count += 1;
            }
        }

        Ok(deleted_count)
    }

    /// Save checkpoint to disk (legacy simple format).
    fn save_checkpoint(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        let checkpoint_path = self
            .checkpoint_dir
            .join(format!("checkpoint_epoch_{}.json", epoch));

        // Create checkpoint data
        let mut checkpoint = HashMap::new();
        checkpoint.insert("epoch".to_string(), epoch as f64);
        checkpoint.insert("train_loss".to_string(), state.train_loss);
        if let Some(val_loss) = state.val_loss {
            checkpoint.insert("val_loss".to_string(), val_loss);
        }

        // Save to JSON
        let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
        })?;

        std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
        })?;

        std::fs::write(&checkpoint_path, json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
        })?;

        // Track checkpoint metadata
        let metadata = CheckpointMetadata {
            epoch,
            val_loss: state.val_loss,
            path: checkpoint_path.clone(),
        };
        self.saved_checkpoints.push(metadata);

        // Auto-cleanup if needed
        if self.keep_top_k.is_some() {
            let deleted = self.cleanup_checkpoints()?;
            if deleted > 0 {
                println!(
                    "Checkpoint saved to {:?} (deleted {} old checkpoints)",
                    checkpoint_path, deleted
                );
            } else {
                println!("Checkpoint saved to {:?}", checkpoint_path);
            }
        } else {
            println!("Checkpoint saved to {:?}", checkpoint_path);
        }

        Ok(())
    }
}

impl Callback for CheckpointCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if !epoch.is_multiple_of(self.save_frequency) {
            return Ok(());
        }

        if self.save_best_only {
            if let Some(val_loss) = state.val_loss {
                let should_save = self
                    .best_val_loss
                    .map(|best| val_loss < best)
                    .unwrap_or(true);

                if should_save {
                    self.best_val_loss = Some(val_loss);
                    self.save_checkpoint(epoch, state)?;
                }
            }
        } else {
            self.save_checkpoint(epoch, state)?;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::env::temp_dir;

    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_checkpoint_callback() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_checkpoints");
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);
        let state = create_test_state();

        callback.on_epoch_end(0, &state).unwrap();

        // Verify checkpoint was created
        let checkpoint_path = checkpoint_dir.join("checkpoint_epoch_0.json");
        assert!(checkpoint_path.exists());

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_training_checkpoint_save_load() {
        // Create test parameters
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
        parameters.insert("bias".to_string(), Array2::from_elem((1, 3), 0.5));

        // Create test state
        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.75,
            val_loss: Some(0.85),
            batch_loss: 0.72,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        // Create optimizer state (mock)
        let optimizer_state = {
            let mut state = HashMap::new();
            state.insert("momentum_weight".to_string(), vec![0.1, 0.2, 0.3]);
            state.insert("momentum_bias".to_string(), vec![0.05]);
            state
        };

        // Create checkpoint
        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.9, 0.8, 0.77, 0.75],
            &[1.1, 0.95, 0.88, 0.87, 0.85],
            &HashMap::new(),
            Some(0.85),
        );

        // Save checkpoint
        let checkpoint_path = temp_dir().join("test_training_checkpoint.json");
        checkpoint.save(&checkpoint_path).unwrap();

        // Verify file exists
        assert!(checkpoint_path.exists());

        // Load checkpoint
        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();

        // Verify data
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.train_loss, 0.75);
        assert_eq!(loaded.val_loss, Some(0.85));
        assert_eq!(loaded.learning_rate, 0.001);
        assert_eq!(loaded.train_loss_history.len(), 5);
        assert_eq!(loaded.val_loss_history.len(), 5);
        assert_eq!(loaded.best_val_loss, Some(0.85));

        // Verify parameters
        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.parameters.contains_key("weight"));
        assert!(loaded.parameters.contains_key("bias"));

        // Verify optimizer state
        assert_eq!(loaded.optimizer_state.len(), 2);
        assert!(loaded.optimizer_state.contains_key("momentum_weight"));

        // Clean up
        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_training_checkpoint_with_metrics() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::zeros((2, 2)));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        // Add metrics history
        let mut metrics_history = HashMap::new();
        metrics_history.insert("accuracy".to_string(), vec![0.5, 0.6, 0.7]);
        metrics_history.insert("f1_score".to_string(), vec![0.45, 0.55, 0.65]);

        let checkpoint = TrainingCheckpoint::new(
            2,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.8, 0.6],
            &[1.1, 0.9, 0.7],
            &metrics_history,
            Some(0.7),
        );

        let checkpoint_path = temp_dir().join("test_checkpoint_with_metrics.json");
        checkpoint.save(&checkpoint_path).unwrap();

        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();

        // Verify metrics
        assert_eq!(loaded.metrics_history.len(), 2);
        assert!(loaded.metrics_history.contains_key("accuracy"));
        assert!(loaded.metrics_history.contains_key("f1_score"));
        assert_eq!(loaded.metrics_history["accuracy"].len(), 3);

        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_checkpoint_compression_gzip() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((100, 100), 1.5));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            10,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &vec![1.0; 100],
            &vec![0.9; 100],
            &HashMap::new(),
            Some(0.5),
        );

        // Save with gzip compression
        let compressed_path = temp_dir().join("test_checkpoint_compressed.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .unwrap();

        // Verify compressed file exists
        assert!(compressed_path.exists());

        // Load compressed checkpoint
        let loaded = TrainingCheckpoint::load(&compressed_path).unwrap();

        // Verify data
        assert_eq!(loaded.epoch, 10);
        assert_eq!(loaded.parameters.len(), 1);
        assert_eq!(loaded.parameters["weights"].len(), 10000); // 100x100

        // Compare file sizes
        let uncompressed_path = temp_dir().join("test_checkpoint_uncompressed.json");
        checkpoint.save(&uncompressed_path).unwrap();

        let compressed_size = std::fs::metadata(&compressed_path).unwrap().len();
        let uncompressed_size = std::fs::metadata(&uncompressed_path).unwrap().len();

        // Compressed should be smaller
        assert!(
            compressed_size < uncompressed_size,
            "Compressed size {} should be less than uncompressed size {}",
            compressed_size,
            uncompressed_size
        );

        // Clean up
        std::fs::remove_file(compressed_path).ok();
        std::fs::remove_file(uncompressed_path).ok();
    }

    #[test]
    fn test_checkpoint_compression_fast_vs_best() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((50, 50), 2.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &vec![1.0; 50],
            &vec![0.8; 50],
            &HashMap::new(),
            None,
        );

        // Save with fast compression
        let fast_path = temp_dir().join("test_checkpoint_fast.json.gz");
        checkpoint
            .save_with_compression(&fast_path, CheckpointCompression::GzipFast)
            .unwrap();

        // Save with best compression
        let best_path = temp_dir().join("test_checkpoint_best.json.gz");
        checkpoint
            .save_with_compression(&best_path, CheckpointCompression::GzipBest)
            .unwrap();

        // Both should be loadable
        let loaded_fast = TrainingCheckpoint::load(&fast_path).unwrap();
        let loaded_best = TrainingCheckpoint::load(&best_path).unwrap();

        assert_eq!(loaded_fast.epoch, 5);
        assert_eq!(loaded_best.epoch, 5);
        assert_eq!(
            loaded_fast.parameters["weights"],
            loaded_best.parameters["weights"]
        );

        // Clean up
        std::fs::remove_file(fast_path).ok();
        std::fs::remove_file(best_path).ok();
    }

    #[test]
    fn test_checkpoint_estimated_size() {
        let mut parameters = HashMap::new();
        parameters.insert("w1".to_string(), Array2::from_elem((10, 10), 1.0));
        parameters.insert("w2".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let train_loss_history: [f64; 10] = [1.0; 10];
        let val_loss_history: [f64; 10] = [0.9; 10];
        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &train_loss_history,
            &val_loss_history,
            &HashMap::new(),
            None,
        );

        let size = checkpoint.estimated_size();
        // 100 + 25 = 125 parameters * 8 bytes + 20 history entries * 8 bytes
        assert!(size > 0);
        assert_eq!(
            size,
            (100 + 25) * std::mem::size_of::<f64>() + 20 * std::mem::size_of::<f64>()
        );
    }

    #[test]
    fn test_checkpoint_auto_detect_compression() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();

        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &HashMap::new(),
            None,
            &state,
            &[1.0],
            &[0.9],
            &HashMap::new(),
            None,
        );

        // Save uncompressed
        let uncompressed_path = temp_dir().join("test_auto_detect.json");
        checkpoint.save(&uncompressed_path).unwrap();

        // Save compressed
        let compressed_path = temp_dir().join("test_auto_detect.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .unwrap();

        // Load both using auto-detection
        let loaded_uncompressed = TrainingCheckpoint::load(&uncompressed_path).unwrap();
        let loaded_compressed = TrainingCheckpoint::load(&compressed_path).unwrap();

        assert_eq!(loaded_uncompressed.epoch, loaded_compressed.epoch);
        assert_eq!(loaded_uncompressed.parameters, loaded_compressed.parameters);

        // Clean up
        std::fs::remove_file(uncompressed_path).ok();
        std::fs::remove_file(compressed_path).ok();
    }

    #[test]
    fn test_checkpoint_auto_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_auto_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with keep_top_k = 3
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 3);

        // Save 5 checkpoints with different validation losses
        let val_losses = [0.9, 0.7, 0.8, 0.6, 0.5]; // Best is 0.5, then 0.6, then 0.7

        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        // Should only have 3 checkpoints remaining (top 3 best)
        assert_eq!(callback.num_saved_checkpoints(), 3);

        // Verify the best 3 checkpoints exist
        assert!(checkpoint_dir.join("checkpoint_epoch_4.json").exists()); // val_loss = 0.5
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists()); // val_loss = 0.6
        assert!(checkpoint_dir.join("checkpoint_epoch_1.json").exists()); // val_loss = 0.7

        // Verify the worst 2 were deleted
        assert!(!checkpoint_dir.join("checkpoint_epoch_0.json").exists()); // val_loss = 0.9
        assert!(!checkpoint_dir.join("checkpoint_epoch_2.json").exists()); // val_loss = 0.8

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_no_cleanup_when_disabled() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_no_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback without cleanup (keep_top_k = None)
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);

        // Save 5 checkpoints
        for epoch in 0..5 {
            let state = create_test_state();
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        // All 5 checkpoints should still exist
        for epoch in 0..5 {
            let path = checkpoint_dir.join(format!("checkpoint_epoch_{}.json", epoch));
            assert!(path.exists(), "Checkpoint {} should exist", epoch);
        }

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_manual_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_manual_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with keep_top_k = 2
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        // Save 4 checkpoints
        let val_losses = [0.8, 0.6, 0.9, 0.5];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        // Should have only top 2
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // Manually trigger cleanup (should do nothing since we're already at top-2)
        let deleted = callback.cleanup_checkpoints().unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_cleanup_without_val_loss() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_cleanup_no_val_loss");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with keep_top_k = 2
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        // Save 4 checkpoints without validation loss
        for epoch in 0..4 {
            let mut state = create_test_state();
            state.val_loss = None; // No validation loss
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        // Should keep top 2 (most recent by epoch)
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // Verify most recent 2 epochs exist
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_2.json").exists());

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_with_save_best_only_and_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_best_and_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with both save_best_only and keep_top_k
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, true, 2);

        // Try to save checkpoints with improving and non-improving losses
        let val_losses = [0.9, 0.7, 0.8, 0.6]; // 0.9 -> 0.7 (save), 0.8 (skip), 0.6 (save)

        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.on_epoch_end(epoch, &state).unwrap();
        }

        // Should only have saved the improving checkpoints (0.9, 0.7, 0.6), then cleaned up to top-2
        assert!(callback.num_saved_checkpoints() <= 2);

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }
}