1use crate::callbacks::core::Callback;
4use crate::{TrainError, TrainResult, TrainingState};
5use flate2::read::GzDecoder;
6use flate2::write::GzEncoder;
7use flate2::Compression;
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::{Read, Write};
11use std::path::PathBuf;
12
/// Compression scheme applied when a checkpoint is serialized to disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CheckpointCompression {
    /// Plain pretty-printed JSON, no compression (the default).
    #[default]
    None,
    /// Gzip at flate2's default compression level.
    Gzip,
    /// Gzip tuned for speed over compression ratio.
    GzipFast,
    /// Gzip tuned for compression ratio over speed.
    GzipBest,
}
26
/// A full, serializable snapshot of training progress — model parameters,
/// optimizer/scheduler state, loss histories, and the bookkeeping needed to
/// resume training from this point.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingCheckpoint {
    /// Epoch at which the snapshot was taken.
    pub epoch: usize,
    /// Model parameters, flattened into a `Vec<f64>` per named tensor.
    pub parameters: HashMap<String, Vec<f64>>,
    /// Optimizer buffers (e.g. momenta), keyed by name.
    pub optimizer_state: HashMap<String, Vec<f64>>,
    /// Learning-rate scheduler state, if a scheduler is in use.
    pub scheduler_state: Option<HashMap<String, f64>>,
    /// Training loss at snapshot time.
    pub train_loss: f64,
    /// Validation loss at snapshot time, if validation ran.
    pub val_loss: Option<f64>,
    /// Per-epoch training-loss history up to this snapshot.
    pub train_loss_history: Vec<f64>,
    /// Per-epoch validation-loss history up to this snapshot.
    pub val_loss_history: Vec<f64>,
    /// Named metric histories (one series of values per metric name).
    pub metrics_history: HashMap<String, Vec<f64>>,
    /// Learning rate in effect at snapshot time.
    pub learning_rate: f64,
    /// Best validation loss observed so far, if any.
    pub best_val_loss: Option<f64>,
}
56
57impl TrainingCheckpoint {
58 #[allow(clippy::too_many_arguments)]
60 pub fn new(
61 epoch: usize,
62 parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
63 optimizer_state: &HashMap<String, Vec<f64>>,
64 scheduler_state: Option<HashMap<String, f64>>,
65 state: &TrainingState,
66 train_loss_history: &[f64],
67 val_loss_history: &[f64],
68 metrics_history: &HashMap<String, Vec<f64>>,
69 best_val_loss: Option<f64>,
70 ) -> Self {
71 let parameters = parameters
73 .iter()
74 .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
75 .collect();
76
77 Self {
78 epoch,
79 parameters,
80 optimizer_state: optimizer_state.clone(),
81 scheduler_state,
82 train_loss: state.train_loss,
83 val_loss: state.val_loss,
84 train_loss_history: train_loss_history.to_vec(),
85 val_loss_history: val_loss_history.to_vec(),
86 metrics_history: metrics_history.clone(),
87 learning_rate: state.learning_rate,
88 best_val_loss,
89 }
90 }
91
92 pub fn save(&self, path: &PathBuf) -> TrainResult<()> {
94 self.save_with_compression(path, CheckpointCompression::None)
95 }
96
97 pub fn save_with_compression(
119 &self,
120 path: &PathBuf,
121 compression: CheckpointCompression,
122 ) -> TrainResult<()> {
123 let json = serde_json::to_string_pretty(self).map_err(|e| {
124 TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
125 })?;
126
127 if let Some(parent) = path.parent() {
128 std::fs::create_dir_all(parent).map_err(|e| {
129 TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
130 })?;
131 }
132
133 match compression {
134 CheckpointCompression::None => {
135 std::fs::write(path, json).map_err(|e| {
136 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
137 })?;
138 }
139 CheckpointCompression::Gzip => {
140 let file = File::create(path).map_err(|e| {
141 TrainError::CheckpointError(format!("Failed to create checkpoint file: {}", e))
142 })?;
143 let mut encoder = GzEncoder::new(file, Compression::default());
144 encoder.write_all(json.as_bytes()).map_err(|e| {
145 TrainError::CheckpointError(format!("Failed to compress checkpoint: {}", e))
146 })?;
147 encoder.finish().map_err(|e| {
148 TrainError::CheckpointError(format!("Failed to finish compression: {}", e))
149 })?;
150 }
151 CheckpointCompression::GzipFast => {
152 let file = File::create(path).map_err(|e| {
153 TrainError::CheckpointError(format!("Failed to create checkpoint file: {}", e))
154 })?;
155 let mut encoder = GzEncoder::new(file, Compression::fast());
156 encoder.write_all(json.as_bytes()).map_err(|e| {
157 TrainError::CheckpointError(format!("Failed to compress checkpoint: {}", e))
158 })?;
159 encoder.finish().map_err(|e| {
160 TrainError::CheckpointError(format!("Failed to finish compression: {}", e))
161 })?;
162 }
163 CheckpointCompression::GzipBest => {
164 let file = File::create(path).map_err(|e| {
165 TrainError::CheckpointError(format!("Failed to create checkpoint file: {}", e))
166 })?;
167 let mut encoder = GzEncoder::new(file, Compression::best());
168 encoder.write_all(json.as_bytes()).map_err(|e| {
169 TrainError::CheckpointError(format!("Failed to compress checkpoint: {}", e))
170 })?;
171 encoder.finish().map_err(|e| {
172 TrainError::CheckpointError(format!("Failed to finish compression: {}", e))
173 })?;
174 }
175 }
176
177 Ok(())
178 }
179
180 pub fn load(path: &PathBuf) -> TrainResult<Self> {
182 if path.to_string_lossy().ends_with(".gz") {
184 Self::load_compressed(path)
185 } else {
186 Self::load_uncompressed(path)
187 }
188 }
189
190 fn load_uncompressed(path: &PathBuf) -> TrainResult<Self> {
192 let json = std::fs::read_to_string(path).map_err(|e| {
193 TrainError::CheckpointError(format!("Failed to read checkpoint: {}", e))
194 })?;
195
196 let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
197 TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
198 })?;
199
200 Ok(checkpoint)
201 }
202
203 pub fn load_compressed(path: &PathBuf) -> TrainResult<Self> {
205 let file = File::open(path).map_err(|e| {
206 TrainError::CheckpointError(format!("Failed to open checkpoint file: {}", e))
207 })?;
208
209 let mut decoder = GzDecoder::new(file);
210 let mut json = String::new();
211 decoder.read_to_string(&mut json).map_err(|e| {
212 TrainError::CheckpointError(format!("Failed to decompress checkpoint: {}", e))
213 })?;
214
215 let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
216 TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
217 })?;
218
219 Ok(checkpoint)
220 }
221
222 pub fn estimated_size(&self) -> usize {
224 let param_size: usize = self
226 .parameters
227 .values()
228 .map(|v| v.len() * std::mem::size_of::<f64>())
229 .sum();
230 let optimizer_size: usize = self
231 .optimizer_state
232 .values()
233 .map(|v| v.len() * std::mem::size_of::<f64>())
234 .sum();
235 let history_size = (self.train_loss_history.len() + self.val_loss_history.len())
236 * std::mem::size_of::<f64>();
237
238 param_size + optimizer_size + history_size
239 }
240}
241
/// Bookkeeping record for one checkpoint file on disk, used to rank
/// checkpoints when enforcing a `keep_top_k` retention policy.
#[derive(Debug, Clone, PartialEq)]
struct CheckpointMetadata {
    /// Epoch the checkpoint was written at.
    epoch: usize,
    /// Validation loss recorded with the checkpoint, if any.
    val_loss: Option<f64>,
    /// Location of the checkpoint file on disk.
    path: PathBuf,
}
252
/// Training callback that writes per-epoch checkpoint files and optionally
/// prunes old ones so at most `keep_top_k` remain on disk.
pub struct CheckpointCallback {
    /// Directory checkpoint files are written into.
    pub checkpoint_dir: PathBuf,
    /// Save on epochs divisible by this frequency; other epochs are skipped.
    pub save_frequency: usize,
    /// When true, save only when validation loss improves on the best seen.
    pub save_best_only: bool,
    /// When `Some(k)`, keep at most `k` checkpoints (best-ranked retained).
    pub keep_top_k: Option<usize>,
    // Best validation loss observed so far, driving `save_best_only`.
    best_val_loss: Option<f64>,
    // Metadata for every checkpoint written by this callback and not yet
    // deleted by cleanup.
    saved_checkpoints: Vec<CheckpointMetadata>,
}
268
269impl CheckpointCallback {
270 pub fn new(checkpoint_dir: PathBuf, save_frequency: usize, save_best_only: bool) -> Self {
272 Self {
273 checkpoint_dir,
274 save_frequency,
275 save_best_only,
276 keep_top_k: None,
277 best_val_loss: None,
278 saved_checkpoints: Vec::new(),
279 }
280 }
281
282 pub fn with_cleanup(
307 checkpoint_dir: PathBuf,
308 save_frequency: usize,
309 save_best_only: bool,
310 keep_top_k: usize,
311 ) -> Self {
312 Self {
313 checkpoint_dir,
314 save_frequency,
315 save_best_only,
316 keep_top_k: Some(keep_top_k),
317 best_val_loss: None,
318 saved_checkpoints: Vec::new(),
319 }
320 }
321
322 pub fn num_saved_checkpoints(&self) -> usize {
324 self.saved_checkpoints.len()
325 }
326
327 pub fn cleanup_checkpoints(&mut self) -> TrainResult<usize> {
332 let keep_top_k = match self.keep_top_k {
333 Some(k) => k,
334 None => return Ok(0), };
336
337 if self.saved_checkpoints.len() <= keep_top_k {
338 return Ok(0); }
340
341 self.saved_checkpoints.sort_by(|a, b| {
344 match (a.val_loss, b.val_loss) {
345 (Some(a_loss), Some(b_loss)) => a_loss
346 .partial_cmp(&b_loss)
347 .unwrap_or(std::cmp::Ordering::Equal),
348 (Some(_), None) => std::cmp::Ordering::Less, (None, Some(_)) => std::cmp::Ordering::Greater, (None, None) => b.epoch.cmp(&a.epoch), }
352 });
353
354 let to_remove: Vec<CheckpointMetadata> =
356 self.saved_checkpoints.drain(keep_top_k..).collect();
357
358 let mut deleted_count = 0;
359 for checkpoint in to_remove {
360 if let Err(e) = std::fs::remove_file(&checkpoint.path) {
361 eprintln!(
362 "Warning: Failed to delete checkpoint {:?}: {}",
363 checkpoint.path, e
364 );
365 } else {
366 deleted_count += 1;
367 }
368 }
369
370 Ok(deleted_count)
371 }
372
373 fn save_checkpoint(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
375 let checkpoint_path = self
376 .checkpoint_dir
377 .join(format!("checkpoint_epoch_{}.json", epoch));
378
379 let mut checkpoint = HashMap::new();
381 checkpoint.insert("epoch".to_string(), epoch as f64);
382 checkpoint.insert("train_loss".to_string(), state.train_loss);
383 if let Some(val_loss) = state.val_loss {
384 checkpoint.insert("val_loss".to_string(), val_loss);
385 }
386
387 let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
389 TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
390 })?;
391
392 std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
393 TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
394 })?;
395
396 std::fs::write(&checkpoint_path, json).map_err(|e| {
397 TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
398 })?;
399
400 let metadata = CheckpointMetadata {
402 epoch,
403 val_loss: state.val_loss,
404 path: checkpoint_path.clone(),
405 };
406 self.saved_checkpoints.push(metadata);
407
408 if self.keep_top_k.is_some() {
410 let deleted = self.cleanup_checkpoints()?;
411 if deleted > 0 {
412 println!(
413 "Checkpoint saved to {:?} (deleted {} old checkpoints)",
414 checkpoint_path, deleted
415 );
416 } else {
417 println!("Checkpoint saved to {:?}", checkpoint_path);
418 }
419 } else {
420 println!("Checkpoint saved to {:?}", checkpoint_path);
421 }
422
423 Ok(())
424 }
425}
426
427impl Callback for CheckpointCallback {
428 fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
429 if !epoch.is_multiple_of(self.save_frequency) {
430 return Ok(());
431 }
432
433 if self.save_best_only {
434 if let Some(val_loss) = state.val_loss {
435 let should_save = self
436 .best_val_loss
437 .map(|best| val_loss < best)
438 .unwrap_or(true);
439
440 if should_save {
441 self.best_val_loss = Some(val_loss);
442 self.save_checkpoint(epoch, state)?;
443 }
444 }
445 } else {
446 self.save_checkpoint(epoch, state)?;
447 }
448
449 Ok(())
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::env::temp_dir;

    /// Minimal training-state fixture shared across the tests below.
    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_checkpoint_callback() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_checkpoints");
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);
        let state = create_test_state();

        callback.on_epoch_end(0, &state).unwrap();

        let checkpoint_path = checkpoint_dir.join("checkpoint_epoch_0.json");
        assert!(checkpoint_path.exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_training_checkpoint_save_load() {
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
        parameters.insert("bias".to_string(), Array2::from_elem((1, 3), 0.5));

        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.75,
            val_loss: Some(0.85),
            batch_loss: 0.72,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        let optimizer_state = {
            let mut state = HashMap::new();
            state.insert("momentum_weight".to_string(), vec![0.1, 0.2, 0.3]);
            state.insert("momentum_bias".to_string(), vec![0.05]);
            state
        };

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.9, 0.8, 0.77, 0.75],
            &[1.1, 0.95, 0.88, 0.87, 0.85],
            &HashMap::new(),
            Some(0.85),
        );

        let checkpoint_path = temp_dir().join("test_training_checkpoint.json");
        checkpoint.save(&checkpoint_path).unwrap();
        assert!(checkpoint_path.exists());

        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();

        // Scalar fields round-trip exactly.
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.train_loss, 0.75);
        assert_eq!(loaded.val_loss, Some(0.85));
        assert_eq!(loaded.learning_rate, 0.001);
        assert_eq!(loaded.train_loss_history.len(), 5);
        assert_eq!(loaded.val_loss_history.len(), 5);
        assert_eq!(loaded.best_val_loss, Some(0.85));

        // Parameters survive flattening + round-trip.
        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.parameters.contains_key("weight"));
        assert!(loaded.parameters.contains_key("bias"));

        assert_eq!(loaded.optimizer_state.len(), 2);
        assert!(loaded.optimizer_state.contains_key("momentum_weight"));

        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_training_checkpoint_with_metrics() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::zeros((2, 2)));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let mut metrics_history = HashMap::new();
        metrics_history.insert("accuracy".to_string(), vec![0.5, 0.6, 0.7]);
        metrics_history.insert("f1_score".to_string(), vec![0.45, 0.55, 0.65]);

        let checkpoint = TrainingCheckpoint::new(
            2,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.8, 0.6],
            &[1.1, 0.9, 0.7],
            &metrics_history,
            Some(0.7),
        );

        let checkpoint_path = temp_dir().join("test_checkpoint_with_metrics.json");
        checkpoint.save(&checkpoint_path).unwrap();

        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();

        assert_eq!(loaded.metrics_history.len(), 2);
        assert!(loaded.metrics_history.contains_key("accuracy"));
        assert!(loaded.metrics_history.contains_key("f1_score"));
        assert_eq!(loaded.metrics_history["accuracy"].len(), 3);

        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_checkpoint_compression_gzip() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((100, 100), 1.5));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            10,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0; 100],
            &[0.9; 100],
            &HashMap::new(),
            Some(0.5),
        );

        let compressed_path = temp_dir().join("test_checkpoint_compressed.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .unwrap();
        assert!(compressed_path.exists());

        // `load` auto-detects gzip from the `.gz` suffix.
        let loaded = TrainingCheckpoint::load(&compressed_path).unwrap();
        assert_eq!(loaded.epoch, 10);
        assert_eq!(loaded.parameters.len(), 1);
        assert_eq!(loaded.parameters["weights"].len(), 10000);

        // Highly repetitive payload must shrink under gzip.
        let uncompressed_path = temp_dir().join("test_checkpoint_uncompressed.json");
        checkpoint.save(&uncompressed_path).unwrap();

        let compressed_size = std::fs::metadata(&compressed_path).unwrap().len();
        let uncompressed_size = std::fs::metadata(&uncompressed_path).unwrap().len();

        assert!(
            compressed_size < uncompressed_size,
            "Compressed size {} should be less than uncompressed size {}",
            compressed_size,
            uncompressed_size
        );

        std::fs::remove_file(compressed_path).ok();
        std::fs::remove_file(uncompressed_path).ok();
    }

    #[test]
    fn test_checkpoint_compression_fast_vs_best() {
        let mut parameters = HashMap::new();
        parameters.insert("weights".to_string(), Array2::from_elem((50, 50), 2.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0; 50],
            &[0.8; 50],
            &HashMap::new(),
            None,
        );

        let fast_path = temp_dir().join("test_checkpoint_fast.json.gz");
        checkpoint
            .save_with_compression(&fast_path, CheckpointCompression::GzipFast)
            .unwrap();

        let best_path = temp_dir().join("test_checkpoint_best.json.gz");
        checkpoint
            .save_with_compression(&best_path, CheckpointCompression::GzipBest)
            .unwrap();

        // Compression level must not affect the decoded contents.
        let loaded_fast = TrainingCheckpoint::load(&fast_path).unwrap();
        let loaded_best = TrainingCheckpoint::load(&best_path).unwrap();

        assert_eq!(loaded_fast.epoch, 5);
        assert_eq!(loaded_best.epoch, 5);
        assert_eq!(
            loaded_fast.parameters["weights"],
            loaded_best.parameters["weights"]
        );

        std::fs::remove_file(fast_path).ok();
        std::fs::remove_file(best_path).ok();
    }

    #[test]
    fn test_checkpoint_estimated_size() {
        let mut parameters = HashMap::new();
        parameters.insert("w1".to_string(), Array2::from_elem((10, 10), 1.0));
        parameters.insert("w2".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let train_loss_history: [f64; 10] = [1.0; 10];
        let val_loss_history: [f64; 10] = [0.9; 10];
        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &train_loss_history,
            &val_loss_history,
            &HashMap::new(),
            None,
        );

        let size = checkpoint.estimated_size();
        assert!(size > 0);
        // (10x10 + 5x5) parameter elements plus 10+10 history entries.
        assert_eq!(
            size,
            (100 + 25) * std::mem::size_of::<f64>() + 20 * std::mem::size_of::<f64>()
        );
    }

    #[test]
    fn test_checkpoint_auto_detect_compression() {
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::from_elem((5, 5), 1.0));

        let state = create_test_state();

        let checkpoint = TrainingCheckpoint::new(
            1,
            &parameters,
            &HashMap::new(),
            None,
            &state,
            &[1.0],
            &[0.9],
            &HashMap::new(),
            None,
        );

        let uncompressed_path = temp_dir().join("test_auto_detect.json");
        checkpoint.save(&uncompressed_path).unwrap();

        let compressed_path = temp_dir().join("test_auto_detect.json.gz");
        checkpoint
            .save_with_compression(&compressed_path, CheckpointCompression::Gzip)
            .unwrap();

        // Same loader entry point must handle both formats transparently.
        let loaded_uncompressed = TrainingCheckpoint::load(&uncompressed_path).unwrap();
        let loaded_compressed = TrainingCheckpoint::load(&compressed_path).unwrap();

        assert_eq!(loaded_uncompressed.epoch, loaded_compressed.epoch);
        assert_eq!(loaded_uncompressed.parameters, loaded_compressed.parameters);

        std::fs::remove_file(uncompressed_path).ok();
        std::fs::remove_file(compressed_path).ok();
    }

    #[test]
    fn test_checkpoint_auto_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_auto_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 3);

        // Epochs 0..4 with these val losses; the three lowest (epochs 4, 3,
        // 1) should survive a keep_top_k of 3.
        let val_losses = [0.9, 0.7, 0.8, 0.6, 0.5];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        assert_eq!(callback.num_saved_checkpoints(), 3);

        assert!(checkpoint_dir.join("checkpoint_epoch_4.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_1.json").exists());
        assert!(!checkpoint_dir.join("checkpoint_epoch_0.json").exists());
        assert!(!checkpoint_dir.join("checkpoint_epoch_2.json").exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_no_cleanup_when_disabled() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_no_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);

        for epoch in 0..5 {
            let state = create_test_state();
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        // Without keep_top_k, every checkpoint must remain on disk.
        for epoch in 0..5 {
            let path = checkpoint_dir.join(format!("checkpoint_epoch_{}.json", epoch));
            assert!(path.exists(), "Checkpoint {} should exist", epoch);
        }

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_manual_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_manual_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        let val_losses = [0.8, 0.6, 0.9, 0.5];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        // Auto-cleanup already trimmed to the limit...
        assert_eq!(callback.num_saved_checkpoints(), 2);

        // ...so a manual pass has nothing left to delete.
        let deleted = callback.cleanup_checkpoints().unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(callback.num_saved_checkpoints(), 2);

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_cleanup_without_val_loss() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_cleanup_no_val_loss");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, false, 2);

        for epoch in 0..4 {
            let mut state = create_test_state();
            state.val_loss = None;
            callback.save_checkpoint(epoch, &state).unwrap();
        }

        assert_eq!(callback.num_saved_checkpoints(), 2);

        // With no val losses, the tie-break keeps the newest epochs.
        assert!(checkpoint_dir.join("checkpoint_epoch_3.json").exists());
        assert!(checkpoint_dir.join("checkpoint_epoch_2.json").exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_checkpoint_with_save_best_only_and_cleanup() {
        let checkpoint_dir = temp_dir().join("tensorlogic_test_best_and_cleanup");
        std::fs::create_dir_all(&checkpoint_dir).ok();

        let mut callback = CheckpointCallback::with_cleanup(checkpoint_dir.clone(), 1, true, 2);

        // Only improving epochs save at all; retention then caps the count.
        let val_losses = [0.9, 0.7, 0.8, 0.6];
        for (epoch, &val_loss) in val_losses.iter().enumerate() {
            let mut state = create_test_state();
            state.val_loss = Some(val_loss);
            callback.on_epoch_end(epoch, &state).unwrap();
        }

        assert!(callback.num_saved_checkpoints() <= 2);

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }
}