use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
9
/// Compression strategy applied when a checkpoint is written to disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CheckpointCompression {
    /// Plain pretty-printed JSON, no compression (the default).
    #[default]
    None,
    /// Gzip at the balanced level (6): moderate speed and ratio.
    Gzip,
    /// Gzip at the fastest level (1): quickest, largest output.
    GzipFast,
    /// Gzip at the highest level (9): slowest, smallest output.
    GzipBest,
}
23
/// Full snapshot of training progress, serializable to JSON (optionally
/// gzip-compressed) and restorable later to resume or inspect a run.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingCheckpoint {
    /// Epoch at which the snapshot was taken.
    pub epoch: usize,
    /// Model parameters, each 2-D array flattened into a `Vec<f64>` keyed by name.
    pub parameters: HashMap<String, Vec<f64>>,
    /// Optimizer buffers (e.g. momentum), keyed by name.
    pub optimizer_state: HashMap<String, Vec<f64>>,
    /// Learning-rate scheduler state, if a scheduler is in use.
    pub scheduler_state: Option<HashMap<String, f64>>,
    /// Training loss at snapshot time.
    pub train_loss: f64,
    /// Validation loss at snapshot time, if validation ran.
    pub val_loss: Option<f64>,
    /// Per-epoch training-loss history up to this checkpoint.
    pub train_loss_history: Vec<f64>,
    /// Per-epoch validation-loss history up to this checkpoint.
    pub val_loss_history: Vec<f64>,
    /// Named metric histories (e.g. accuracy per epoch).
    pub metrics_history: HashMap<String, Vec<f64>>,
    /// Learning rate in effect at snapshot time.
    pub learning_rate: f64,
    /// Best validation loss observed so far in the run, if any.
    pub best_val_loss: Option<f64>,
}
53
54impl TrainingCheckpoint {
55 #[allow(clippy::too_many_arguments)]
57 pub fn new(
58 epoch: usize,
59 parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
60 optimizer_state: &HashMap<String, Vec<f64>>,
61 scheduler_state: Option<HashMap<String, f64>>,
62 state: &TrainingState,
63 train_loss_history: &[f64],
64 val_loss_history: &[f64],
65 metrics_history: &HashMap<String, Vec<f64>>,
66 best_val_loss: Option<f64>,
67 ) -> Self {
68 let parameters = parameters
70 .iter()
71 .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
72 .collect();
73
74 Self {
75 epoch,
76 parameters,
77 optimizer_state: optimizer_state.clone(),
78 scheduler_state,
79 train_loss: state.train_loss,
80 val_loss: state.val_loss,
81 train_loss_history: train_loss_history.to_vec(),
82 val_loss_history: val_loss_history.to_vec(),
83 metrics_history: metrics_history.clone(),
84 learning_rate: state.learning_rate,
85 best_val_loss,
86 }
87 }
88
89 pub fn save(&self, path: &PathBuf) -> TrainResult<()> {
91 self.save_with_compression(path, CheckpointCompression::None)
92 }
93
94 pub fn save_with_compression(
116 &self,
117 path: &PathBuf,
118 compression: CheckpointCompression,
119 ) -> TrainResult<()> {
120 let json = serde_json::to_string_pretty(self).map_err(|e| {
121 TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
122 })?;
123
124 if let Some(parent) = path.parent() {
125 std::fs::create_dir_all(parent).map_err(|e| {
126 TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
127 })?;
128 }
129
130 match compression {
131 CheckpointCompression::None => {
132 std::fs::write(path, json).map_err(|e| {
133 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
134 })?;
135 }
136 CheckpointCompression::Gzip => {
137 let compressed =
139 oxiarc_deflate::gzip_compress(json.as_bytes(), 6).map_err(|e| {
140 TrainError::CheckpointError(format!(
141 "Failed to gzip-compress checkpoint: {}",
142 e
143 ))
144 })?;
145 std::fs::write(path, compressed).map_err(|e| {
146 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
147 })?;
148 }
149 CheckpointCompression::GzipFast => {
150 let compressed =
152 oxiarc_deflate::gzip_compress(json.as_bytes(), 1).map_err(|e| {
153 TrainError::CheckpointError(format!(
154 "Failed to gzip-compress checkpoint (fast): {}",
155 e
156 ))
157 })?;
158 std::fs::write(path, compressed).map_err(|e| {
159 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
160 })?;
161 }
162 CheckpointCompression::GzipBest => {
163 let compressed =
165 oxiarc_deflate::gzip_compress(json.as_bytes(), 9).map_err(|e| {
166 TrainError::CheckpointError(format!(
167 "Failed to gzip-compress checkpoint (best): {}",
168 e
169 ))
170 })?;
171 std::fs::write(path, compressed).map_err(|e| {
172 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
173 })?;
174 }
175 }
176
177 Ok(())
178 }
179
180 pub fn load(path: &PathBuf) -> TrainResult<Self> {
182 if path.to_string_lossy().ends_with(".gz") {
184 Self::load_compressed(path)
185 } else {
186 Self::load_uncompressed(path)
187 }
188 }
189
190 fn load_uncompressed(path: &PathBuf) -> TrainResult<Self> {
192 let json = std::fs::read_to_string(path).map_err(|e| {
193 TrainError::CheckpointError(format!("Failed to read checkpoint: {}", e))
194 })?;
195
196 let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
197 TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
198 })?;
199
200 Ok(checkpoint)
201 }
202
203 pub fn load_compressed(path: &PathBuf) -> TrainResult<Self> {
205 let mut file = File::open(path).map_err(|e| {
206 TrainError::CheckpointError(format!("Failed to open checkpoint file: {}", e))
207 })?;
208
209 let mut compressed = Vec::new();
210 file.read_to_end(&mut compressed).map_err(|e| {
211 TrainError::CheckpointError(format!("Failed to read checkpoint file: {}", e))
212 })?;
213
214 let decompressed = oxiarc_deflate::gzip_decompress(&compressed).map_err(|e| {
215 TrainError::CheckpointError(format!("Failed to decompress checkpoint: {}", e))
216 })?;
217
218 let json = String::from_utf8(decompressed).map_err(|e| {
219 TrainError::CheckpointError(format!(
220 "Decompressed checkpoint is not valid UTF-8: {}",
221 e
222 ))
223 })?;
224
225 let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
226 TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
227 })?;
228
229 Ok(checkpoint)
230 }
231
232 pub fn estimated_size(&self) -> usize {
234 let param_size: usize = self
236 .parameters
237 .values()
238 .map(|v| v.len() * std::mem::size_of::<f64>())
239 .sum();
240 let optimizer_size: usize = self
241 .optimizer_state
242 .values()
243 .map(|v| v.len() * std::mem::size_of::<f64>())
244 .sum();
245 let history_size = (self.train_loss_history.len() + self.val_loss_history.len())
246 * std::mem::size_of::<f64>();
247
248 param_size + optimizer_size + history_size
249 }
250}
251
/// Bookkeeping record for one checkpoint file written by `CheckpointCallback`,
/// used to rank and evict checkpoints during cleanup.
#[derive(Debug, Clone, PartialEq)]
struct CheckpointMetadata {
    /// Epoch the checkpoint was saved at.
    epoch: usize,
    /// Validation loss at save time, if one was available.
    val_loss: Option<f64>,
    /// Location of the checkpoint file on disk.
    path: PathBuf,
}
262
/// Training callback that periodically writes checkpoint files and can
/// optionally retain only the top-k checkpoints by validation loss.
pub struct CheckpointCallback {
    /// Directory checkpoint files are written into.
    pub checkpoint_dir: PathBuf,
    /// Save on epochs whose number is a multiple of this frequency.
    pub save_frequency: usize,
    /// When true, save only if validation loss improved on the best so far.
    pub save_best_only: bool,
    /// When `Some(k)`, delete all but the k best checkpoints after each save.
    pub keep_top_k: Option<usize>,
    // Best validation loss seen so far; drives `save_best_only`.
    best_val_loss: Option<f64>,
    // Metadata for checkpoints written by this callback and not yet deleted.
    saved_checkpoints: Vec<CheckpointMetadata>,
}
278
279impl CheckpointCallback {
280 pub fn new(checkpoint_dir: PathBuf, save_frequency: usize, save_best_only: bool) -> Self {
282 Self {
283 checkpoint_dir,
284 save_frequency,
285 save_best_only,
286 keep_top_k: None,
287 best_val_loss: None,
288 saved_checkpoints: Vec::new(),
289 }
290 }
291
292 pub fn with_cleanup(
317 checkpoint_dir: PathBuf,
318 save_frequency: usize,
319 save_best_only: bool,
320 keep_top_k: usize,
321 ) -> Self {
322 Self {
323 checkpoint_dir,
324 save_frequency,
325 save_best_only,
326 keep_top_k: Some(keep_top_k),
327 best_val_loss: None,
328 saved_checkpoints: Vec::new(),
329 }
330 }
331
332 pub fn num_saved_checkpoints(&self) -> usize {
334 self.saved_checkpoints.len()
335 }
336
337 pub fn cleanup_checkpoints(&mut self) -> TrainResult<usize> {
342 let keep_top_k = match self.keep_top_k {
343 Some(k) => k,
344 None => return Ok(0), };
346
347 if self.saved_checkpoints.len() <= keep_top_k {
348 return Ok(0); }
350
351 self.saved_checkpoints.sort_by(|a, b| {
354 match (a.val_loss, b.val_loss) {
355 (Some(a_loss), Some(b_loss)) => a_loss
356 .partial_cmp(&b_loss)
357 .unwrap_or(std::cmp::Ordering::Equal),
358 (Some(_), None) => std::cmp::Ordering::Less, (None, Some(_)) => std::cmp::Ordering::Greater, (None, None) => b.epoch.cmp(&a.epoch), }
362 });
363
364 let to_remove: Vec<CheckpointMetadata> =
366 self.saved_checkpoints.drain(keep_top_k..).collect();
367
368 let mut deleted_count = 0;
369 for checkpoint in to_remove {
370 if let Err(e) = std::fs::remove_file(&checkpoint.path) {
371 eprintln!(
372 "Warning: Failed to delete checkpoint {:?}: {}",
373 checkpoint.path, e
374 );
375 } else {
376 deleted_count += 1;
377 }
378 }
379
380 Ok(deleted_count)
381 }
382
383 fn save_checkpoint(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
385 let checkpoint_path = self
386 .checkpoint_dir
387 .join(format!("checkpoint_epoch_{}.json", epoch));
388
389 let mut checkpoint = HashMap::new();
391 checkpoint.insert("epoch".to_string(), epoch as f64);
392 checkpoint.insert("train_loss".to_string(), state.train_loss);
393 if let Some(val_loss) = state.val_loss {
394 checkpoint.insert("val_loss".to_string(), val_loss);
395 }
396
397 let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
399 TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
400 })?;
401
402 std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
403 TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
404 })?;
405
406 std::fs::write(&checkpoint_path, json).map_err(|e| {
407 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
408 })?;
409
410 let metadata = CheckpointMetadata {
412 epoch,
413 val_loss: state.val_loss,
414 path: checkpoint_path.clone(),
415 };
416 self.saved_checkpoints.push(metadata);
417
418 if self.keep_top_k.is_some() {
420 let deleted = self.cleanup_checkpoints()?;
421 if deleted > 0 {
422 println!(
423 "Checkpoint saved to {:?} (deleted {} old checkpoints)",
424 checkpoint_path, deleted
425 );
426 } else {
427 println!("Checkpoint saved to {:?}", checkpoint_path);
428 }
429 } else {
430 println!("Checkpoint saved to {:?}", checkpoint_path);
431 }
432
433 Ok(())
434 }
435}
436
437impl Callback for CheckpointCallback {
438 fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
439 if !epoch.is_multiple_of(self.save_frequency) {
440 return Ok(());
441 }
442
443 if self.save_best_only {
444 if let Some(val_loss) = state.val_loss {
445 let should_save = self
446 .best_val_loss
447 .map(|best| val_loss < best)
448 .unwrap_or(true);
449
450 if should_save {
451 self.best_val_loss = Some(val_loss);
452 self.save_checkpoint(epoch, state)?;
453 }
454 }
455 } else {
456 self.save_checkpoint(epoch, state)?;
457 }
458
459 Ok(())
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::env::temp_dir;

    // Minimal TrainingState shared by the tests below.
    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    // on_epoch_end writes a checkpoint file named after the epoch.
    #[test]
    fn test_checkpoint_callback() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_checkpoints");
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);
        let state = create_test_state();

        callback.on_epoch_end(0, &state).expect("unwrap");

        let checkpoint_path = checkpoint_dir.join("checkpoint_epoch_0.json");
        assert!(checkpoint_path.exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    // Round-trip: save a full TrainingCheckpoint and load it back unchanged.
    #[test]
    fn test_training_checkpoint_save_load() {
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
        parameters.insert("bias".to_string(), Array2::from_elem((1, 3), 0.5));

        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.75,
            val_loss: Some(0.85),
            batch_loss: 0.72,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        let optimizer_state = {
            let mut state = HashMap::new();
            state.insert("momentum_weight".to_string(), vec![0.1, 0.2, 0.3]);
            state.insert("momentum_bias".to_string(), vec![0.05]);
            state
        };

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.9, 0.8, 0.77, 0.75],
            &[1.1, 0.95, 0.88, 0.87, 0.85],
            &HashMap::new(),
            Some(0.85),
        );

        let checkpoint_path = temp_dir().join("test_training_checkpoint.json");
        checkpoint.save(&checkpoint_path).expect("unwrap");

        assert!(checkpoint_path.exists());

        let loaded = TrainingCheckpoint::load(&checkpoint_path).expect("unwrap");

        // Scalar fields and histories survive the round trip.
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.train_loss, 0.75);
        assert_eq!(loaded.val_loss, Some(0.85));
        assert_eq!(loaded.learning_rate, 0.001);
        assert_eq!(loaded.train_loss_history.len(), 5);
        assert_eq!(loaded.val_loss_history.len(), 5);
        assert_eq!(loaded.best_val_loss, Some(0.85));

        // Parameter tensors are present under their original names.
        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.parameters.contains_key("weight"));
        assert!(loaded.parameters.contains_key("bias"));

        assert_eq!(loaded.optimizer_state.len(), 2);
        assert!(loaded.optimizer_state.contains_key("momentum_weight"));

        std::fs::remove_file(checkpoint_path).ok();
    }

    // Metric histories are preserved across save/load.
    #[test]
    fn test_training_checkpoint_with_metrics() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::zeros((2, 2)));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let mut metrics_history = HashMap::new();
        metrics_history.insert("accuracy".to_string(), vec![0.5, 0.6, 0.7]);
        metrics_history.insert("f1_score".to_string(), vec![0.45, 0.55, 0.65]);

        let checkpoint = TrainingCheckpoint::new(
            2,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.8, 0.6],
            &[1.1, 0.9, 0.7],
            &metrics_history,
            Some(0.7),
        );

        let checkpoint_path = temp_dir().join("test_checkpoint_with_metrics.json");
        checkpoint.save(&checkpoint_path).expect("unwrap");

        let loaded = TrainingCheckpoint::load(&checkpoint_path).expect("unwrap");

        assert_eq!(loaded.metrics_history.len(), 2);
        assert!(loaded.metrics_history.contains_key("accuracy"));
        assert!(loaded.metrics_history.contains_key("f1_score"));
        assert_eq!(loaded.metrics_history["accuracy"].len(), 3);

        std::fs::remove_file(checkpoint_path).ok();
    }

    // Gzip round-trips losslessly and shrinks a highly repetitive checkpoint.
    #[test]
    fn test_checkpoint_compression_gzip() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((100, 100), 1.5));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            10,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &vec![1.0; 100],
            &vec![0.9; 100],
            &HashMap::new(),
            Some(0.5),
        );

        let compressed_path = temp_dir().join("test_checkpoint_compressed.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .expect("unwrap");

        assert!(compressed_path.exists());

        // `load` auto-detects compression from the `.gz` suffix.
        let loaded = TrainingCheckpoint::load(&compressed_path).expect("unwrap");

        assert_eq!(loaded.epoch, 10);
        assert_eq!(loaded.parameters.len(), 1);
        assert_eq!(loaded.parameters["weights"].len(), 10000);

        let uncompressed_path = temp_dir().join("test_checkpoint_uncompressed.json");
        checkpoint.save(&uncompressed_path).expect("unwrap");

        let compressed_size = std::fs::metadata(&compressed_path).expect("unwrap").len();
        let uncompressed_size = std::fs::metadata(&uncompressed_path).expect("unwrap").len();

        assert!(
            compressed_size < uncompressed_size,
            "Compressed size {} should be less than uncompressed size {}",
            compressed_size,
            uncompressed_size
        );

        std::fs::remove_file(compressed_path).ok();
        std::fs::remove_file(uncompressed_path).ok();
    }

    // Fast and best gzip levels both round-trip to identical data.
    #[test]
    fn test_checkpoint_compression_fast_vs_best() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((50, 50), 2.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &vec![1.0; 50],
            &vec![0.8; 50],
            &HashMap::new(),
            None,
        );

        let fast_path = temp_dir().join("test_checkpoint_fast.json.gz");
        checkpoint
            .save_with_compression(&fast_path, CheckpointCompression::GzipFast)
            .expect("unwrap");

        let best_path = temp_dir().join("test_checkpoint_best.json.gz");
        checkpoint
            .save_with_compression(&best_path, CheckpointCompression::GzipBest)
            .expect("unwrap");

        let loaded_fast = TrainingCheckpoint::load(&fast_path).expect("unwrap");
        let loaded_best = TrainingCheckpoint::load(&best_path).expect("unwrap");

        assert_eq!(loaded_fast.epoch, 5);
        assert_eq!(loaded_best.epoch, 5);
        assert_eq!(
            loaded_fast.parameters["weights"],
            loaded_best.parameters["weights"]
        );

        std::fs::remove_file(fast_path).ok();
        std::fs::remove_file(best_path).ok();
    }

    // estimated_size counts parameters, optimizer state, and both loss
    // histories (scheduler state and metrics are excluded by design).
    #[test]
    fn test_checkpoint_estimated_size() {
        let mut parameters = HashMap::new();
        parameters.insert("w1".to_string(), Array2::from_elem((10, 10), 1.0));
        parameters.insert("w2".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let train_loss_history: [f64; 10] = [1.0; 10];
        let val_loss_history: [f64; 10] = [0.9; 10];
        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &train_loss_history,
            &val_loss_history,
            &HashMap::new(),
            None,
        );

        let size = checkpoint.estimated_size();
        assert!(size > 0);
        // 10x10 + 5x5 parameter elements, plus 10 + 10 history entries.
        assert_eq!(
            size,
            (100 + 25) * std::mem::size_of::<f64>() + 20 * std::mem::size_of::<f64>()
        );
    }

    // load() picks the right decoder for .json vs .json.gz paths.
    #[test]
    fn test_checkpoint_auto_detect_compression() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();

        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &HashMap::new(),
            None,
            &state,
            &[1.0],
            &[0.9],
            &HashMap::new(),
            None,
        );

        let uncompressed_path = temp_dir().join("test_auto_detect.json");
        checkpoint.save(&uncompressed_path).expect("unwrap");

        let compressed_path = temp_dir().join("test_auto_detect.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .expect("unwrap");

        let loaded_uncompressed = TrainingCheckpoint::load(&uncompressed_path).expect("unwrap");
        let loaded_compressed = TrainingCheckpoint::load(&compressed_path).expect("unwrap");

        assert_eq!(loaded_uncompressed.epoch, loaded_compressed.epoch);
        assert_eq!(loaded_uncompressed.parameters, loaded_compressed.parameters);

        std::fs::remove_file(uncompressed_path).ok();
        std::fs::remove_file(compressed_path).ok();
    }

    // With keep_top_k = 3, only the three lowest-val-loss checkpoints
    // (epochs 4, 3, 1) survive; epochs 0 and 2 are deleted.
    #[test]
    fn test_checkpoint_auto_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_auto_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 3);

        let val_losses = [0.9, 0.7, 0.8, 0.6, 0.5];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        assert_eq!(callback.num_saved_checkpoints(), 3);

        assert!(checkpoint_dir.join("checkpoint_epoch_4.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_1.json").exists());
        assert!(!checkpoint_dir.join("checkpoint_epoch_0.json").exists());
        assert!(!checkpoint_dir.join("checkpoint_epoch_2.json").exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    // Without keep_top_k, every checkpoint stays on disk.
    #[test]
    fn test_checkpoint_no_cleanup_when_disabled() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_no_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);

        for epoch in 0..5 {
            let state = create_test_state();
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        for epoch in 0..5 {
            let path = checkpoint_dir.join(format!("checkpoint_epoch_{}.json", epoch));
            assert!(path.exists(), "Checkpoint {} should exist", epoch);
        }

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    // cleanup_checkpoints is a no-op once the count is within the limit.
    #[test]
    fn test_checkpoint_manual_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_manual_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        let val_losses = [0.8, 0.6, 0.9, 0.5];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        assert_eq!(callback.num_saved_checkpoints(), 2);

        let deleted = callback.cleanup_checkpoints().expect("unwrap");
        assert_eq!(deleted, 0);
        assert_eq!(callback.num_saved_checkpoints(), 2);

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    // When no val losses exist, cleanup keeps the most recent epochs.
    #[test]
    fn test_checkpoint_cleanup_without_val_loss() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_cleanup_no_val_loss");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        for epoch in 0..4 {
            let mut state = create_test_state();
            state.val_loss = None;
            callback.save_checkpoint(epoch, &state).expect("unwrap");
        }

        assert_eq!(callback.num_saved_checkpoints(), 2);

        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_2.json").exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    // save_best_only filters saves, so cleanup never sees more than k files.
    #[test]
    fn test_checkpoint_with_save_best_only_and_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_best_and_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, true, 2);

        let val_losses = [0.9, 0.7, 0.8, 0.6];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.on_epoch_end(epoch, &state).expect("unwrap");
        }

        assert!(callback.num_saved_checkpoints() <= 2);

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }
}