tensorlogic_train/callbacks/checkpoint.rs

//! Checkpoint callbacks for saving and loading training state.
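//!
//! A minimal sketch of the typical flow (the directory path and epoch
//! values are illustrative, not part of the API):
//!
//! ```no_run
//! use std::path::PathBuf;
//! use tensorlogic_train::CheckpointCallback;
//!
//! // Save a checkpoint every 2 epochs, keeping only the 3 best by
//! // validation loss.
//! let callback = CheckpointCallback::with_cleanup(
//!     PathBuf::from("/tmp/checkpoints"),
//!     2,     // save_frequency
//!     false, // save_best_only
//!     3,     // keep_top_k
//! );
//! ```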

use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;

/// Compression method for checkpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CheckpointCompression {
    /// No compression (plain JSON).
    #[default]
    None,
    /// Gzip compression (good balance of speed and ratio).
    Gzip,
    /// Fast gzip compression (faster but lower ratio).
    GzipFast,
    /// Best gzip compression (slower but better ratio).
    GzipBest,
}

/// Comprehensive checkpoint data structure.
///
/// This structure contains all the information needed to fully restore
/// training state, including model parameters, optimizer state, and training history.
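///
/// A minimal save/load round-trip sketch (the checkpoint value and path are
/// illustrative; a real checkpoint comes from [`TrainingCheckpoint::new`]):
///
/// ```no_run
/// use tensorlogic_train::TrainingCheckpoint;
/// use std::path::PathBuf;
///
/// # let checkpoint: TrainingCheckpoint = unimplemented!();
/// let path = PathBuf::from("/tmp/checkpoint.json");
/// checkpoint.save(&path).expect("failed to save checkpoint");
///
/// let restored = TrainingCheckpoint::load(&path).expect("failed to load checkpoint");
/// assert_eq!(restored.epoch, checkpoint.epoch);
/// ```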
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingCheckpoint {
    /// Current epoch number.
    pub epoch: usize,
    /// Model parameters as flattened vectors.
    pub parameters: HashMap<String, Vec<f64>>,
    /// Optimizer state as flattened vectors.
    pub optimizer_state: HashMap<String, Vec<f64>>,
    /// Scheduler state (if present).
    pub scheduler_state: Option<HashMap<String, f64>>,
    /// Current training loss.
    pub train_loss: f64,
    /// Current validation loss (if available).
    pub val_loss: Option<f64>,
    /// Training loss history.
    pub train_loss_history: Vec<f64>,
    /// Validation loss history.
    pub val_loss_history: Vec<f64>,
    /// Metrics history.
    pub metrics_history: HashMap<String, Vec<f64>>,
    /// Current learning rate.
    pub learning_rate: f64,
    /// Best validation loss seen so far.
    pub best_val_loss: Option<f64>,
}

impl TrainingCheckpoint {
    /// Create a new checkpoint from current training state.
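    ///
    /// A minimal construction sketch with toy values (the 2-D parameter
    /// arrays are flattened into `Vec<f64>` internally; the training state
    /// is assumed to come from the training loop):
    ///
    /// ```no_run
    /// use std::collections::HashMap;
    /// use scirs2_core::ndarray::Array2;
    /// use tensorlogic_train::{TrainingCheckpoint, TrainingState};
    ///
    /// let mut parameters = HashMap::new();
    /// parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
    ///
    /// # let state: TrainingState = unimplemented!();
    /// let checkpoint = TrainingCheckpoint::new(
    ///     0,                // epoch
    ///     &parameters,
    ///     &HashMap::new(),  // optimizer state
    ///     None,             // scheduler state
    ///     &state,
    ///     &[1.0],           // train loss history
    ///     &[0.9],           // val loss history
    ///     &HashMap::new(),  // metrics history
    ///     None,             // best val loss
    /// );
    /// ```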
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        epoch: usize,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
        optimizer_state: &HashMap<String, Vec<f64>>,
        scheduler_state: Option<HashMap<String, f64>>,
        state: &TrainingState,
        train_loss_history: &[f64],
        val_loss_history: &[f64],
        metrics_history: &HashMap<String, Vec<f64>>,
        best_val_loss: Option<f64>,
    ) -> Self {
        // Convert parameters to flat vectors
        let parameters = parameters
            .iter()
            .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
            .collect();

        Self {
            epoch,
            parameters,
            optimizer_state: optimizer_state.clone(),
            scheduler_state,
            train_loss: state.train_loss,
            val_loss: state.val_loss,
            train_loss_history: train_loss_history.to_vec(),
            val_loss_history: val_loss_history.to_vec(),
            metrics_history: metrics_history.clone(),
            learning_rate: state.learning_rate,
            best_val_loss,
        }
    }

    /// Save checkpoint to a file as plain (uncompressed) JSON.
    ///
    /// Equivalent to [`save_with_compression`](Self::save_with_compression)
    /// with [`CheckpointCompression::None`].
    pub fn save(&self, path: &PathBuf) -> TrainResult<()> {
        self.save_with_compression(path, CheckpointCompression::None)
    }

    /// Save checkpoint to a file with compression.
    ///
    /// # Arguments
    /// * `path` - Path to save the checkpoint
    /// * `compression` - Compression method to use
    ///
    /// # Example
    /// ```no_run
    /// use tensorlogic_train::TrainingCheckpoint;
    /// use tensorlogic_train::CheckpointCompression;
    /// use std::path::PathBuf;
    ///
    /// // Assuming you have a checkpoint...
    /// # let checkpoint: TrainingCheckpoint = unimplemented!();
    ///
    /// // Save with gzip compression
    /// checkpoint.save_with_compression(
    ///     &PathBuf::from("/tmp/checkpoint.json.gz"),
    ///     CheckpointCompression::Gzip
    /// ).expect("failed to save checkpoint");
    /// ```
    pub fn save_with_compression(
        &self,
        path: &PathBuf,
        compression: CheckpointCompression,
    ) -> TrainResult<()> {
        let json = serde_json::to_string_pretty(self).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
        })?;

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
            })?;
        }

        match compression {
            CheckpointCompression::None => {
                std::fs::write(path, json).map_err(|e| {
                    TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
                })?;
            }
            CheckpointCompression::Gzip
            | CheckpointCompression::GzipFast
            | CheckpointCompression::GzipBest => {
                // Gzip levels mirror flate2: 1 = fast, 6 = default, 9 = best.
                let (level, label) = match compression {
                    CheckpointCompression::GzipFast => (1, " (fast)"),
                    CheckpointCompression::Gzip => (6, ""),
                    CheckpointCompression::GzipBest => (9, " (best)"),
                    CheckpointCompression::None => unreachable!(),
                };
                let compressed = oxiarc_deflate::gzip_compress(json.as_bytes(), level)
                    .map_err(|e| {
                        TrainError::CheckpointError(format!(
                            "Failed to gzip-compress checkpoint{}: {}",
                            label, e
                        ))
                    })?;
                std::fs::write(path, compressed).map_err(|e| {
                    TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
                })?;
            }
        }

        Ok(())
    }

    /// Load checkpoint from a file, auto-detecting gzip compression from a
    /// `.gz` file extension.
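    ///
    /// A small sketch of loading either format through the same entry point
    /// (paths are illustrative):
    ///
    /// ```no_run
    /// use tensorlogic_train::TrainingCheckpoint;
    /// use std::path::PathBuf;
    ///
    /// // Plain JSON is read directly...
    /// let _plain = TrainingCheckpoint::load(&PathBuf::from("/tmp/ckpt.json"));
    /// // ...while a `.gz` suffix routes through gzip decompression.
    /// let _gzipped = TrainingCheckpoint::load(&PathBuf::from("/tmp/ckpt.json.gz"));
    /// ```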
    pub fn load(path: &PathBuf) -> TrainResult<Self> {
        // Auto-detect compression based on file extension
        if path.to_string_lossy().ends_with(".gz") {
            Self::load_compressed(path)
        } else {
            Self::load_uncompressed(path)
        }
    }

    /// Load uncompressed checkpoint from a file.
    fn load_uncompressed(path: &PathBuf) -> TrainResult<Self> {
        let json = std::fs::read_to_string(path).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to read checkpoint: {}", e))
        })?;

        let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
        })?;

        Ok(checkpoint)
    }

    /// Load compressed checkpoint from a file.
    pub fn load_compressed(path: &PathBuf) -> TrainResult<Self> {
        let mut file = File::open(path).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to open checkpoint file: {}", e))
        })?;

        let mut compressed = Vec::new();
        file.read_to_end(&mut compressed).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to read checkpoint file: {}", e))
        })?;

        let decompressed = oxiarc_deflate::gzip_decompress(&compressed).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to decompress checkpoint: {}", e))
        })?;

        let json = String::from_utf8(decompressed).map_err(|e| {
            TrainError::CheckpointError(format!(
                "Decompressed checkpoint is not valid UTF-8: {}",
                e
            ))
        })?;

        let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
        })?;

        Ok(checkpoint)
    }

    /// Estimate the checkpoint's size in bytes, counting only the numeric
    /// payload (parameters, optimizer state, and loss histories); metrics
    /// history and JSON serialization overhead are not included.
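    ///
    /// A small sketch of the arithmetic (sizes assume 8-byte `f64` values;
    /// the checkpoint itself is assumed to exist):
    ///
    /// ```no_run
    /// use tensorlogic_train::TrainingCheckpoint;
    ///
    /// # let checkpoint: TrainingCheckpoint = unimplemented!();
    /// // With 125 parameter values and 20 loss-history entries this is
    /// // (125 + 20) * 8 = 1160 bytes.
    /// println!("~{} bytes", checkpoint.estimated_size());
    /// ```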
    pub fn estimated_size(&self) -> usize {
        // Rough estimate: parameters + optimizer_state + histories
        let param_size: usize = self
            .parameters
            .values()
            .map(|v| v.len() * std::mem::size_of::<f64>())
            .sum();
        let optimizer_size: usize = self
            .optimizer_state
            .values()
            .map(|v| v.len() * std::mem::size_of::<f64>())
            .sum();
        let history_size = (self.train_loss_history.len() + self.val_loss_history.len())
            * std::mem::size_of::<f64>();

        param_size + optimizer_size + history_size
    }
}

/// Checkpoint metadata for tracking saved checkpoints.
#[derive(Debug, Clone, PartialEq)]
struct CheckpointMetadata {
    /// Epoch number.
    epoch: usize,
    /// Validation loss (if available).
    val_loss: Option<f64>,
    /// File path.
    path: PathBuf,
}

/// Callback for model checkpointing with auto-cleanup.
pub struct CheckpointCallback {
    /// Directory to save checkpoints.
    pub checkpoint_dir: PathBuf,
    /// Frequency of checkpointing (every N epochs).
    pub save_frequency: usize,
    /// Whether to save only the best model.
    pub save_best_only: bool,
    /// Maximum number of checkpoints to keep (None = keep all).
    pub keep_top_k: Option<usize>,
    /// Best validation loss seen so far.
    best_val_loss: Option<f64>,
    /// Metadata of saved checkpoints for cleanup.
    saved_checkpoints: Vec<CheckpointMetadata>,
}

impl CheckpointCallback {
    /// Create a new checkpoint callback with no auto-cleanup (every saved
    /// checkpoint is kept).
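    ///
    /// A minimal sketch (the directory path is illustrative):
    ///
    /// ```no_run
    /// use tensorlogic_train::CheckpointCallback;
    /// use std::path::PathBuf;
    ///
    /// // Save every 5 epochs, regardless of validation loss.
    /// let callback = CheckpointCallback::new(
    ///     PathBuf::from("/tmp/checkpoints"),
    ///     5,     // save_frequency
    ///     false, // save_best_only
    /// );
    /// ```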
    pub fn new(checkpoint_dir: PathBuf, save_frequency: usize, save_best_only: bool) -> Self {
        Self {
            checkpoint_dir,
            save_frequency,
            save_best_only,
            keep_top_k: None,
            best_val_loss: None,
            saved_checkpoints: Vec::new(),
        }
    }

    /// Create a new checkpoint callback with auto-cleanup.
    ///
    /// This will automatically delete old checkpoints when the number exceeds `keep_top_k`,
    /// keeping only the checkpoints with the best (lowest) validation loss.
    ///
    /// # Arguments
    /// * `checkpoint_dir` - Directory to save checkpoints
    /// * `save_frequency` - Save every N epochs
    /// * `save_best_only` - Only save when validation loss improves
    /// * `keep_top_k` - Maximum number of checkpoints to keep (keeps best by validation loss)
    ///
    /// # Example
    /// ```no_run
    /// use tensorlogic_train::CheckpointCallback;
    /// use std::path::PathBuf;
    ///
    /// // Keep only the top 5 best checkpoints
    /// let callback = CheckpointCallback::with_cleanup(
    ///     PathBuf::from("/tmp/checkpoints"),
    ///     1,     // save every epoch
    ///     false, // save all, not just best
    ///     5,     // keep top 5
    /// );
    /// ```
    pub fn with_cleanup(
        checkpoint_dir: PathBuf,
        save_frequency: usize,
        save_best_only: bool,
        keep_top_k: usize,
    ) -> Self {
        Self {
            checkpoint_dir,
            save_frequency,
            save_best_only,
            keep_top_k: Some(keep_top_k),
            best_val_loss: None,
            saved_checkpoints: Vec::new(),
        }
    }

    /// Get the number of saved checkpoints being tracked.
    pub fn num_saved_checkpoints(&self) -> usize {
        self.saved_checkpoints.len()
    }

    /// Manually clean up checkpoints, keeping only the top-k best.
    ///
    /// Call this to trigger cleanup immediately, e.g. after changing the
    /// `keep_top_k` setting. Returns the number of checkpoint files deleted.
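    ///
    /// A small sketch of tightening `keep_top_k` after the fact (the path
    /// and values are illustrative):
    ///
    /// ```no_run
    /// use tensorlogic_train::CheckpointCallback;
    /// use std::path::PathBuf;
    ///
    /// let mut callback = CheckpointCallback::with_cleanup(
    ///     PathBuf::from("/tmp/checkpoints"), 1, false, 10);
    /// // ... training saves a number of checkpoints ...
    /// // Tighten the retention policy, then prune immediately.
    /// callback.keep_top_k = Some(3);
    /// let deleted = callback.cleanup_checkpoints().expect("cleanup failed");
    /// println!("removed {} checkpoints", deleted);
    /// ```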
    pub fn cleanup_checkpoints(&mut self) -> TrainResult<usize> {
        let keep_top_k = match self.keep_top_k {
            Some(k) => k,
            None => return Ok(0), // No cleanup needed
        };

        if self.saved_checkpoints.len() <= keep_top_k {
            return Ok(0); // Don't need to clean up yet
        }

        // Sort by validation loss (ascending - best first)
        // For checkpoints without val_loss, prefer more recent epochs
        self.saved_checkpoints.sort_by(|a, b| {
            match (a.val_loss, b.val_loss) {
                (Some(a_loss), Some(b_loss)) => a_loss
                    .partial_cmp(&b_loss)
                    .unwrap_or(std::cmp::Ordering::Equal),
                (Some(_), None) => std::cmp::Ordering::Less, // Prefer checkpoints with val_loss
                (None, Some(_)) => std::cmp::Ordering::Greater, // Prefer checkpoints with val_loss
                (None, None) => b.epoch.cmp(&a.epoch),       // Prefer newer epochs (descending)
            }
        });

        // Remove checkpoints beyond top-k
        let to_remove: Vec<CheckpointMetadata> =
            self.saved_checkpoints.drain(keep_top_k..).collect();

        let mut deleted_count = 0;
        for checkpoint in to_remove {
            if let Err(e) = std::fs::remove_file(&checkpoint.path) {
                eprintln!(
                    "Warning: Failed to delete checkpoint {:?}: {}",
                    checkpoint.path, e
                );
            } else {
                deleted_count += 1;
            }
        }

        Ok(deleted_count)
    }

    /// Save a lightweight checkpoint to disk (legacy simple format: just
    /// the epoch and losses, not the full `TrainingCheckpoint` payload).
    fn save_checkpoint(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        let checkpoint_path = self
            .checkpoint_dir
            .join(format!("checkpoint_epoch_{}.json", epoch));

        // Create checkpoint data
        let mut checkpoint = HashMap::new();
        checkpoint.insert("epoch".to_string(), epoch as f64);
        checkpoint.insert("train_loss".to_string(), state.train_loss);
        if let Some(val_loss) = state.val_loss {
            checkpoint.insert("val_loss".to_string(), val_loss);
        }

        // Save to JSON
        let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
        })?;

        std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
        })?;

        std::fs::write(&checkpoint_path, json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
        })?;

        // Track checkpoint metadata
        let metadata = CheckpointMetadata {
            epoch,
            val_loss: state.val_loss,
            path: checkpoint_path.clone(),
        };
        self.saved_checkpoints.push(metadata);

        // Auto-cleanup if needed
        if self.keep_top_k.is_some() {
            let deleted = self.cleanup_checkpoints()?;
            if deleted > 0 {
                println!(
                    "Checkpoint saved to {:?} (deleted {} old checkpoints)",
                    checkpoint_path, deleted
                );
            } else {
                println!("Checkpoint saved to {:?}", checkpoint_path);
            }
        } else {
            println!("Checkpoint saved to {:?}", checkpoint_path);
        }

        Ok(())
    }
}

impl Callback for CheckpointCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if !epoch.is_multiple_of(self.save_frequency) {
            return Ok(());
        }

        if self.save_best_only {
            if let Some(val_loss) = state.val_loss {
                let should_save = self
                    .best_val_loss
                    .map(|best| val_loss < best)
                    .unwrap_or(true);

                if should_save {
                    self.best_val_loss = Some(val_loss);
                    self.save_checkpoint(epoch, state)?;
                }
            }
        } else {
            self.save_checkpoint(epoch, state)?;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::env::temp_dir;

    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_checkpoint_callback() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_checkpoints");
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);
        let state = create_test_state();

        callback.on_epoch_end(0, &state).expect("unwrap");

        // Verify checkpoint was created
        let checkpoint_path = checkpoint_dir.join("checkpoint_epoch_0.json");
        assert!(checkpoint_path.exists());

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_training_checkpoint_save_load() {
        // Create test parameters
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
        parameters.insert("bias".to_string(), Array2::from_elem((1, 3), 0.5));

        // Create test state
        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.75,
            val_loss: Some(0.85),
            batch_loss: 0.72,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        // Create optimizer state (mock)
        let optimizer_state = {
            let mut state = HashMap::new();
            state.insert("momentum_weight".to_string(), vec![0.1, 0.2, 0.3]);
            state.insert("momentum_bias".to_string(), vec![0.05]);
            state
        };

        // Create checkpoint
        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.9, 0.8, 0.77, 0.75],
            &[1.1, 0.95, 0.88, 0.87, 0.85],
            &HashMap::new(),
            Some(0.85),
        );

        // Save checkpoint
        let checkpoint_path = temp_dir().join("test_training_checkpoint.json");
        checkpoint.save(&checkpoint_path).expect("unwrap");

        // Verify file exists
        assert!(checkpoint_path.exists());

        // Load checkpoint
        let loaded = TrainingCheckpoint::load(&checkpoint_path).expect("unwrap");

        // Verify data
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.train_loss, 0.75);
        assert_eq!(loaded.val_loss, Some(0.85));
        assert_eq!(loaded.learning_rate, 0.001);
        assert_eq!(loaded.train_loss_history.len(), 5);
        assert_eq!(loaded.val_loss_history.len(), 5);
        assert_eq!(loaded.best_val_loss, Some(0.85));

        // Verify parameters
        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.parameters.contains_key("weight"));
        assert!(loaded.parameters.contains_key("bias"));

        // Verify optimizer state
        assert_eq!(loaded.optimizer_state.len(), 2);
        assert!(loaded.optimizer_state.contains_key("momentum_weight"));

        // Clean up
        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_training_checkpoint_with_metrics() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::zeros((2, 2)));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        // Add metrics history
        let mut metrics_history = HashMap::new();
        metrics_history.insert("accuracy".to_string(), vec![0.5, 0.6, 0.7]);
        metrics_history.insert("f1_score".to_string(), vec![0.45, 0.55, 0.65]);

        let checkpoint = TrainingCheckpoint::new(
            2,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.8, 0.6],
            &[1.1, 0.9, 0.7],
            &metrics_history,
            Some(0.7),
        );

        let checkpoint_path = temp_dir().join("test_checkpoint_with_metrics.json");
        checkpoint.save(&checkpoint_path).expect("unwrap");

        let loaded = TrainingCheckpoint::load(&checkpoint_path).expect("unwrap");

        // Verify metrics
        assert_eq!(loaded.metrics_history.len(), 2);
        assert!(loaded.metrics_history.contains_key("accuracy"));
        assert!(loaded.metrics_history.contains_key("f1_score"));
        assert_eq!(loaded.metrics_history["accuracy"].len(), 3);

        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_checkpoint_compression_gzip() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((100, 100), 1.5));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            10,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &vec![1.0; 100],
            &vec![0.9; 100],
            &HashMap::new(),
            Some(0.5),
        );

        // Save with gzip compression
        let compressed_path = temp_dir().join("test_checkpoint_compressed.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .expect("unwrap");

        // Verify compressed file exists
        assert!(compressed_path.exists());

        // Load compressed checkpoint
        let loaded = TrainingCheckpoint::load(&compressed_path).expect("unwrap");

        // Verify data
        assert_eq!(loaded.epoch, 10);
        assert_eq!(loaded.parameters.len(), 1);
        assert_eq!(loaded.parameters["weights"].len(), 10000); // 100x100

        // Compare file sizes
        let uncompressed_path = temp_dir().join("test_checkpoint_uncompressed.json");
        checkpoint.save(&uncompressed_path).expect("unwrap");

        let compressed_size = std::fs::metadata(&compressed_path).expect("unwrap").len();
        let uncompressed_size = std::fs::metadata(&uncompressed_path).expect("unwrap").len();

        // Compressed should be smaller
        assert!(
            compressed_size < uncompressed_size,
            "Compressed size {} should be less than uncompressed size {}",
            compressed_size,
            uncompressed_size
        );

        // Clean up
        std::fs::remove_file(compressed_path).ok();
        std::fs::remove_file(uncompressed_path).ok();
    }

    #[test]
    fn test_checkpoint_compression_fast_vs_best() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((50, 50), 2.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &vec![1.0; 50],
            &vec![0.8; 50],
            &HashMap::new(),
            None,
        );

        // Save with fast compression
        let fast_path = temp_dir().join("test_checkpoint_fast.json.gz");
        checkpoint
            .save_with_compression(&fast_path, CheckpointCompression::GzipFast)
            .expect("unwrap");

        // Save with best compression
        let best_path = temp_dir().join("test_checkpoint_best.json.gz");
        checkpoint
            .save_with_compression(&best_path, CheckpointCompression::GzipBest)
            .expect("unwrap");

        // Both should be loadable
        let loaded_fast = TrainingCheckpoint::load(&fast_path).expect("unwrap");
        let loaded_best = TrainingCheckpoint::load(&best_path).expect("unwrap");

        assert_eq!(loaded_fast.epoch, 5);
        assert_eq!(loaded_best.epoch, 5);
        assert_eq!(
            loaded_fast.parameters["weights"],
            loaded_best.parameters["weights"]
        );

        // Clean up
        std::fs::remove_file(fast_path).ok();
        std::fs::remove_file(best_path).ok();
    }

    #[test]
    fn test_checkpoint_estimated_size() {
        let mut parameters = HashMap::new();
        parameters.insert("w1".to_string(), Array2::from_elem((10, 10), 1.0));
        parameters.insert("w2".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let train_loss_history: [f64; 10] = [1.0; 10];
        let val_loss_history: [f64; 10] = [0.9; 10];
        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &train_loss_history,
            &val_loss_history,
            &HashMap::new(),
            None,
        );

        let size = checkpoint.estimated_size();
        // 100 + 25 = 125 parameters * 8 bytes + 20 history entries * 8 bytes
        assert!(size > 0);
        assert_eq!(
            size,
            (100 + 25) * std::mem::size_of::<f64>() + 20 * std::mem::size_of::<f64>()
        );
    }

    #[test]
    fn test_checkpoint_auto_detect_compression() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();

        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &HashMap::new(),
            None,
            &state,
            &[1.0],
            &[0.9],
            &HashMap::new(),
            None,
        );

        // Save uncompressed
        let uncompressed_path = temp_dir().join("test_auto_detect.json");
        checkpoint.save(&uncompressed_path).expect("unwrap");

        // Save compressed
        let compressed_path = temp_dir().join("test_auto_detect.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .expect("unwrap");

        // Load both using auto-detection
        let loaded_uncompressed = TrainingCheckpoint::load(&uncompressed_path).expect("unwrap");
        let loaded_compressed = TrainingCheckpoint::load(&compressed_path).expect("unwrap");

        assert_eq!(loaded_uncompressed.epoch, loaded_compressed.epoch);
        assert_eq!(loaded_uncompressed.parameters, loaded_compressed.parameters);

        // Clean up
        std::fs::remove_file(uncompressed_path).ok();
        std::fs::remove_file(compressed_path).ok();
    }

    #[test]
    fn test_checkpoint_auto_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_auto_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with keep_top_k = 3
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 3);

        // Save 5 checkpoints with different validation losses
        let val_losses = [0.9, 0.7, 0.8, 0.6, 0.5]; // Best is 0.5, then 0.6, then 0.7

        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        // Should only have 3 checkpoints remaining (top 3 best)
        assert_eq!(callback.num_saved_checkpoints(), 3);

        // Verify the best 3 checkpoints exist
        assert!(checkpoint_dir.join("checkpoint_epoch_4.json").exists()); // val_loss = 0.5
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists()); // val_loss = 0.6
        assert!(checkpoint_dir.join("checkpoint_epoch_1.json").exists()); // val_loss = 0.7

        // Verify the worst 2 were deleted
        assert!(!checkpoint_dir.join("checkpoint_epoch_0.json").exists()); // val_loss = 0.9
        assert!(!checkpoint_dir.join("checkpoint_epoch_2.json").exists()); // val_loss = 0.8

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_no_cleanup_when_disabled() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_no_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback without cleanup (keep_top_k = None)
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);

        // Save 5 checkpoints
        for epoch in 0..5 {
            let state = create_test_state();
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        // All 5 checkpoints should still exist
        for epoch in 0..5 {
            let path = checkpoint_dir.join(format!("checkpoint_epoch_{}.json", epoch));
            assert!(path.exists(), "Checkpoint {} should exist", epoch);
        }

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_manual_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_manual_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with keep_top_k = 2
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        // Save 4 checkpoints
        let val_losses = [0.8, 0.6, 0.9, 0.5];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        // Should have only top 2
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // Manually trigger cleanup (should do nothing since we're already at top-2)
        let deleted = callback.cleanup_checkpoints().expect("unwrap");
        assert_eq!(deleted, 0);
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_cleanup_without_val_loss() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_cleanup_no_val_loss");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with keep_top_k = 2
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        // Save 4 checkpoints without validation loss
        for epoch in 0..4 {
            let mut state = create_test_state();
            state.val_loss = None; // No validation loss
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        // Should keep top 2 (most recent by epoch)
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // Verify most recent 2 epochs exist
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_2.json").exists());

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_with_save_best_only_and_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_best_and_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        // Create callback with both save_best_only and keep_top_k
        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, true, 2);

        // Try to save checkpoints with improving and non-improving losses
        let val_losses = [0.9, 0.7, 0.8, 0.6]; // 0.9 -> 0.7 (save), 0.8 (skip), 0.6 (save)

        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.on_epoch_end(epoch, &state).expect("unwrap");
        }

        // Should only have saved the improving checkpoints (0.9, 0.7, 0.6), then cleaned up to top-2
        assert!(callback.num_saved_checkpoints() <= 2);

        // Clean up
        std::fs::remove_dir_all(checkpoint_dir).ok();
    }
}