1use std::collections::{HashMap, VecDeque};
7use std::path::{Path, PathBuf};
8
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct ParamState {
16 pub name: String,
17 pub first_moment: Vec<f64>,
19 pub second_moment: Vec<f64>,
21 pub step: u64,
22 pub shape: Vec<usize>,
23}
24
25impl ParamState {
26 pub fn new(
28 name: impl Into<String>,
29 first_moment: Vec<f64>,
30 second_moment: Vec<f64>,
31 step: u64,
32 shape: Vec<usize>,
33 ) -> Self {
34 Self {
35 name: name.into(),
36 first_moment,
37 second_moment,
38 step,
39 shape,
40 }
41 }
42}
43
44#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
46pub struct CheckpointMetadata {
47 pub created_at_step: u64,
48 pub loss: Option<f64>,
49 pub val_loss: Option<f64>,
50 pub extra: HashMap<String, String>,
51}
52
53impl CheckpointMetadata {
54 pub fn new(created_at_step: u64) -> Self {
56 Self {
57 created_at_step,
58 loss: None,
59 val_loss: None,
60 extra: HashMap::new(),
61 }
62 }
63}
64
65#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
67pub struct OptimizerCheckpoint {
68 pub step: u64,
69 pub epoch: u32,
70 pub optimizer_name: String,
71 pub param_states: HashMap<String, ParamState>,
73 pub hyperparams: HashMap<String, f64>,
74 pub metadata: CheckpointMetadata,
75}
76
77impl OptimizerCheckpoint {
78 pub fn new(optimizer_name: impl Into<String>, step: u64, epoch: u32) -> Self {
80 Self {
81 step,
82 epoch,
83 optimizer_name: optimizer_name.into(),
84 param_states: HashMap::new(),
85 hyperparams: HashMap::new(),
86 metadata: CheckpointMetadata::new(step),
87 }
88 }
89
90 pub fn add_param_state(&mut self, name: impl Into<String>, state: ParamState) {
92 self.param_states.insert(name.into(), state);
93 }
94
95 pub fn set_hyperparam(&mut self, key: impl Into<String>, value: f64) {
97 self.hyperparams.insert(key.into(), value);
98 }
99
100 pub fn get_hyperparam(&self, key: &str) -> Option<f64> {
102 self.hyperparams.get(key).copied()
103 }
104
105 pub fn num_params(&self) -> usize {
107 self.param_states.len()
108 }
109
110 pub fn total_elements(&self) -> usize {
112 self.param_states
113 .values()
114 .map(|ps| ps.first_moment.len())
115 .sum()
116 }
117}
118
/// On-disk encoding used for a checkpoint.
#[derive(Debug, Clone)]
pub enum CheckpointFormat {
    /// Length-prefixed binary container (file extension `tlck`).
    Binary,
    /// Line-oriented `key=value` text (file extension `tlckt`).
    Text,
}

impl CheckpointFormat {
    /// File extension (without the leading dot) used for checkpoints in this format.
    fn file_extension(&self) -> &'static str {
        match *self {
            Self::Text => "tlckt",
            Self::Binary => "tlck",
        }
    }
}
140
/// Errors produced while saving, loading, or (de)serializing checkpoints.
///
/// Variants that carry a `String` hold a pre-formatted description of the failure.
#[derive(Debug, Clone)]
pub enum CheckpointError {
    /// A filesystem read, write, or delete failed.
    IoError(String),
    /// A checkpoint could not be encoded.
    SerializationError(String),
    /// Checkpoint bytes could not be decoded.
    DeserializationError(String),
    /// No checkpoint is tracked for the requested step.
    CheckpointNotFound { step: u64 },
    /// No checkpoint has been saved yet.
    NoCheckpointsAvailable,
    /// The container around the payload is malformed (bad magic, version, or length).
    InvalidFormat(String),
    /// The checkpoint directory could not be created.
    DirectoryCreationFailed(String),
}

impl std::fmt::Display for CheckpointError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants that wrap a message share the "<label>: <message>" layout;
        // the two that do not are written directly.
        let (label, detail) = match self {
            Self::CheckpointNotFound { step } => {
                return write!(f, "Checkpoint not found for step {step}");
            }
            Self::NoCheckpointsAvailable => {
                return f.write_str("No checkpoints are available");
            }
            Self::IoError(detail) => ("IO error", detail),
            Self::SerializationError(detail) => ("Serialization error", detail),
            Self::DeserializationError(detail) => ("Deserialization error", detail),
            Self::InvalidFormat(detail) => ("Invalid format", detail),
            Self::DirectoryCreationFailed(detail) => ("Directory creation failed", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for CheckpointError {}
182
/// Renders `values` as a comma-separated list using each value's `Display` form.
///
/// An empty slice yields an empty string.
fn encode_f64_slice(values: &[f64]) -> String {
    let mut encoded = String::new();
    for (index, value) in values.iter().enumerate() {
        if index > 0 {
            encoded.push(',');
        }
        encoded.push_str(&value.to_string());
    }
    encoded
}
195
196fn decode_f64_slice(s: &str) -> Result<Vec<f64>, CheckpointError> {
198 if s.is_empty() {
199 return Ok(Vec::new());
200 }
201 s.split(',')
202 .map(|tok| {
203 tok.trim()
204 .parse::<f64>()
205 .map_err(|e| CheckpointError::DeserializationError(format!("f64 parse: {e}")))
206 })
207 .collect()
208}
209
/// Renders `values` as a comma-separated list of decimal integers.
///
/// An empty slice yields an empty string.
fn encode_usize_slice(values: &[usize]) -> String {
    let rendered: Vec<String> = values.iter().map(usize::to_string).collect();
    rendered.join(",")
}
218
219fn decode_usize_slice(s: &str) -> Result<Vec<usize>, CheckpointError> {
221 if s.is_empty() {
222 return Ok(Vec::new());
223 }
224 s.split(',')
225 .map(|tok| {
226 tok.trim()
227 .parse::<usize>()
228 .map_err(|e| CheckpointError::DeserializationError(format!("usize parse: {e}")))
229 })
230 .collect()
231}
232
233fn serialize_text(ckpt: &OptimizerCheckpoint) -> Vec<u8> {
235 let mut out = String::new();
236
237 out.push_str("section=header\n");
239 out.push_str(&format!("step={}\n", ckpt.step));
240 out.push_str(&format!("epoch={}\n", ckpt.epoch));
241 out.push_str(&format!("optimizer_name={}\n", ckpt.optimizer_name));
242 out.push_str(&format!(
243 "created_at_step={}\n",
244 ckpt.metadata.created_at_step
245 ));
246 if let Some(loss) = ckpt.metadata.loss {
247 out.push_str(&format!("loss={loss}\n"));
248 }
249 if let Some(val_loss) = ckpt.metadata.val_loss {
250 out.push_str(&format!("val_loss={val_loss}\n"));
251 }
252 for (k, v) in &ckpt.metadata.extra {
253 out.push_str(&format!("extra.{k}={v}\n"));
254 }
255
256 out.push_str("\n---\n");
257
258 out.push_str("section=hyperparams\n");
260 for (k, v) in &ckpt.hyperparams {
261 out.push_str(&format!("hp.{k}={v}\n"));
262 }
263
264 out.push_str("\n---\n");
265
266 out.push_str("section=param_states\n");
268 for (param_name, ps) in &ckpt.param_states {
269 out.push_str(&format!("param.name={param_name}\n"));
270 out.push_str(&format!(
271 "param.first_moment={}\n",
272 encode_f64_slice(&ps.first_moment)
273 ));
274 out.push_str(&format!(
275 "param.second_moment={}\n",
276 encode_f64_slice(&ps.second_moment)
277 ));
278 out.push_str(&format!("param.step={}\n", ps.step));
279 out.push_str(&format!("param.shape={}\n", encode_usize_slice(&ps.shape)));
280 out.push_str("param.end\n");
281 }
282
283 out.into_bytes()
284}
285
286fn deserialize_text(bytes: &[u8]) -> Result<OptimizerCheckpoint, CheckpointError> {
288 let text = std::str::from_utf8(bytes)
289 .map_err(|e| CheckpointError::DeserializationError(format!("UTF-8: {e}")))?;
290
291 let mut step: Option<u64> = None;
292 let mut epoch: Option<u32> = None;
293 let mut optimizer_name: Option<String> = None;
294 let mut created_at_step: u64 = 0;
295 let mut loss: Option<f64> = None;
296 let mut val_loss: Option<f64> = None;
297 let mut extra: HashMap<String, String> = HashMap::new();
298 let mut hyperparams: HashMap<String, f64> = HashMap::new();
299 let mut param_states: HashMap<String, ParamState> = HashMap::new();
300
301 let mut cur_name: Option<String> = None;
303 let mut cur_first: Vec<f64> = Vec::new();
304 let mut cur_second: Vec<f64> = Vec::new();
305 let mut cur_step: u64 = 0;
306 let mut cur_shape: Vec<usize> = Vec::new();
307
308 for raw_line in text.lines() {
309 let line = raw_line.trim();
310 if line.is_empty() || line == "---" {
311 continue;
312 }
313 if line.starts_with("section=") {
314 continue;
315 }
316 if line == "param.end" {
317 if let Some(name) = cur_name.take() {
318 param_states.insert(
319 name.clone(),
320 ParamState {
321 name,
322 first_moment: std::mem::take(&mut cur_first),
323 second_moment: std::mem::take(&mut cur_second),
324 step: cur_step,
325 shape: std::mem::take(&mut cur_shape),
326 },
327 );
328 }
329 cur_step = 0;
330 continue;
331 }
332
333 let (key, value) = line.split_once('=').ok_or_else(|| {
334 CheckpointError::DeserializationError(format!("Missing '=' in line: {line}"))
335 })?;
336
337 match key {
338 "step" => {
339 step =
340 Some(value.parse::<u64>().map_err(|e| {
341 CheckpointError::DeserializationError(format!("step: {e}"))
342 })?);
343 }
344 "epoch" => {
345 epoch =
346 Some(value.parse::<u32>().map_err(|e| {
347 CheckpointError::DeserializationError(format!("epoch: {e}"))
348 })?);
349 }
350 "optimizer_name" => {
351 optimizer_name = Some(value.to_owned());
352 }
353 "created_at_step" => {
354 created_at_step = value.parse::<u64>().map_err(|e| {
355 CheckpointError::DeserializationError(format!("created_at_step: {e}"))
356 })?;
357 }
358 "loss" => {
359 loss =
360 Some(value.parse::<f64>().map_err(|e| {
361 CheckpointError::DeserializationError(format!("loss: {e}"))
362 })?);
363 }
364 "val_loss" => {
365 val_loss = Some(value.parse::<f64>().map_err(|e| {
366 CheckpointError::DeserializationError(format!("val_loss: {e}"))
367 })?);
368 }
369 "param.name" => {
370 cur_name = Some(value.to_owned());
371 }
372 "param.first_moment" => {
373 cur_first = decode_f64_slice(value)?;
374 }
375 "param.second_moment" => {
376 cur_second = decode_f64_slice(value)?;
377 }
378 "param.step" => {
379 cur_step = value.parse::<u64>().map_err(|e| {
380 CheckpointError::DeserializationError(format!("param.step: {e}"))
381 })?;
382 }
383 "param.shape" => {
384 cur_shape = decode_usize_slice(value)?;
385 }
386 other if other.starts_with("hp.") => {
387 let hp_key = other.trim_start_matches("hp.");
388 let hp_val = value.parse::<f64>().map_err(|e| {
389 CheckpointError::DeserializationError(format!("hyperparam {hp_key}: {e}"))
390 })?;
391 hyperparams.insert(hp_key.to_owned(), hp_val);
392 }
393 other if other.starts_with("extra.") => {
394 let ex_key = other.trim_start_matches("extra.");
395 extra.insert(ex_key.to_owned(), value.to_owned());
396 }
397 _ => {} }
399 }
400
401 let step =
402 step.ok_or_else(|| CheckpointError::DeserializationError("missing field: step".into()))?;
403 let epoch = epoch
404 .ok_or_else(|| CheckpointError::DeserializationError("missing field: epoch".into()))?;
405 let optimizer_name = optimizer_name.ok_or_else(|| {
406 CheckpointError::DeserializationError("missing field: optimizer_name".into())
407 })?;
408
409 Ok(OptimizerCheckpoint {
410 step,
411 epoch,
412 optimizer_name,
413 param_states,
414 hyperparams,
415 metadata: CheckpointMetadata {
416 created_at_step,
417 loss,
418 val_loss,
419 extra,
420 },
421 })
422}
423
/// Magic bytes that open every binary checkpoint: ASCII `TLCK`.
const BINARY_MAGIC: [u8; 4] = *b"TLCK";
/// Version number written into the binary header.
const BINARY_VERSION: u32 = 1;
427
428pub fn serialize_checkpoint(
430 ckpt: &OptimizerCheckpoint,
431 format: CheckpointFormat,
432) -> Result<Vec<u8>, CheckpointError> {
433 match format {
434 CheckpointFormat::Text => Ok(serialize_text(ckpt)),
435 CheckpointFormat::Binary => {
436 let json = serde_json::to_vec(ckpt)
438 .map_err(|e| CheckpointError::SerializationError(format!("JSON: {e}")))?;
439
440 let payload_len = json.len() as u32;
442 let mut out = Vec::with_capacity(12 + json.len());
443 out.extend_from_slice(&BINARY_MAGIC);
444 out.extend_from_slice(&BINARY_VERSION.to_be_bytes());
445 out.extend_from_slice(&payload_len.to_be_bytes());
446 out.extend_from_slice(&json);
447 Ok(out)
448 }
449 }
450}
451
452pub fn deserialize_checkpoint(
454 bytes: &[u8],
455 format: CheckpointFormat,
456) -> Result<OptimizerCheckpoint, CheckpointError> {
457 match format {
458 CheckpointFormat::Text => deserialize_text(bytes),
459 CheckpointFormat::Binary => {
460 if bytes.len() < 12 {
462 return Err(CheckpointError::InvalidFormat(
463 "binary checkpoint too short".into(),
464 ));
465 }
466 if bytes[..4] != BINARY_MAGIC {
467 return Err(CheckpointError::InvalidFormat(
468 "bad magic bytes — not a TLCK checkpoint".into(),
469 ));
470 }
471 let version = u32::from_be_bytes(
472 bytes[4..8]
473 .try_into()
474 .map_err(|_| CheckpointError::InvalidFormat("version bytes".into()))?,
475 );
476 if version != BINARY_VERSION {
477 return Err(CheckpointError::InvalidFormat(format!(
478 "unsupported version {version}"
479 )));
480 }
481 let payload_len = u32::from_be_bytes(
482 bytes[8..12]
483 .try_into()
484 .map_err(|_| CheckpointError::InvalidFormat("length bytes".into()))?,
485 ) as usize;
486 let payload_end = 12 + payload_len;
487 if bytes.len() < payload_end {
488 return Err(CheckpointError::InvalidFormat(
489 "truncated binary checkpoint".into(),
490 ));
491 }
492 let json = &bytes[12..payload_end];
493 serde_json::from_slice(json)
494 .map_err(|e| CheckpointError::DeserializationError(format!("JSON: {e}")))
495 }
496 }
497}
498
499pub struct CheckpointManager {
508 pub dir: PathBuf,
509 pub max_to_keep: usize,
510 pub format: CheckpointFormat,
511 saved: Vec<PathBuf>,
513}
514
515impl CheckpointManager {
516 pub fn new(
518 dir: impl AsRef<Path>,
519 max_to_keep: usize,
520 format: CheckpointFormat,
521 ) -> Result<Self, CheckpointError> {
522 let dir = dir.as_ref().to_path_buf();
523 std::fs::create_dir_all(&dir).map_err(|e| {
524 CheckpointError::DirectoryCreationFailed(format!("{}: {e}", dir.display()))
525 })?;
526 Ok(Self {
527 dir,
528 max_to_keep,
529 format,
530 saved: Vec::new(),
531 })
532 }
533
534 fn checkpoint_filename(step: u64, format: &CheckpointFormat) -> String {
536 format!("ckpt-step-{:012}.{}", step, format.file_extension())
537 }
538
539 pub fn save(&mut self, ckpt: &OptimizerCheckpoint) -> Result<PathBuf, CheckpointError> {
541 let filename = Self::checkpoint_filename(ckpt.step, &self.format);
542 let path = self.dir.join(&filename);
543
544 let bytes = serialize_checkpoint(ckpt, self.format.clone())?;
545 std::fs::write(&path, &bytes)
546 .map_err(|e| CheckpointError::IoError(format!("write {}: {e}", path.display())))?;
547
548 self.saved.push(path.clone());
549 self.prune_old()?;
550 Ok(path)
551 }
552
553 pub fn load_latest(&self) -> Result<OptimizerCheckpoint, CheckpointError> {
555 let path = self
556 .saved
557 .last()
558 .ok_or(CheckpointError::NoCheckpointsAvailable)?;
559 self.load_from_path(path)
560 }
561
562 pub fn load_at_step(&self, step: u64) -> Result<OptimizerCheckpoint, CheckpointError> {
564 let filename = Self::checkpoint_filename(step, &self.format);
565 let path = self.dir.join(&filename);
566 if !self.saved.iter().any(|p| p == &path) {
567 return Err(CheckpointError::CheckpointNotFound { step });
568 }
569 self.load_from_path(&path)
570 }
571
572 pub fn list(&self) -> Vec<(u64, &Path)> {
574 self.saved
575 .iter()
576 .filter_map(|p| {
577 let stem = p.file_stem()?.to_str()?;
579 let step_str = stem.strip_prefix("ckpt-step-")?;
580 let step = step_str.parse::<u64>().ok()?;
581 Some((step, p.as_path()))
582 })
583 .collect()
584 }
585
586 pub fn count(&self) -> usize {
588 self.saved.len()
589 }
590
591 fn load_from_path(&self, path: &Path) -> Result<OptimizerCheckpoint, CheckpointError> {
594 let bytes = std::fs::read(path)
595 .map_err(|e| CheckpointError::IoError(format!("read {}: {e}", path.display())))?;
596 deserialize_checkpoint(&bytes, self.format.clone())
597 }
598
599 fn prune_old(&mut self) -> Result<(), CheckpointError> {
601 while self.saved.len() > self.max_to_keep {
602 let oldest = self.saved.remove(0);
603 if oldest.exists() {
604 std::fs::remove_file(&oldest).map_err(|e| {
605 CheckpointError::IoError(format!("delete {}: {e}", oldest.display()))
606 })?;
607 }
608 }
609 Ok(())
610 }
611}
612
/// Sliding window over the most recent loss values.
#[derive(Debug, Clone)]
pub struct LossTracker {
    /// Maximum number of samples retained. May be changed after construction;
    /// the new bound is enforced on the next [`LossTracker::push`].
    pub window_size: usize,
    /// Retained samples, oldest first.
    history: VecDeque<f64>,
}

impl LossTracker {
    /// Creates an empty tracker that retains at most `window_size` samples.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            history: VecDeque::with_capacity(window_size),
        }
    }

    /// Appends `loss`, evicting the oldest samples so that at most
    /// `window_size` remain.
    ///
    /// With a `window_size` of zero nothing is retained. Eviction also copes
    /// with `window_size` having been lowered since the last push. A plain
    /// `len == window_size` test would never fire in either case and the
    /// history would grow without bound.
    pub fn push(&mut self, loss: f64) {
        if self.window_size == 0 {
            self.history.clear();
            return;
        }
        while self.history.len() >= self.window_size {
            self.history.pop_front();
        }
        self.history.push_back(loss);
    }

    /// Arithmetic mean of the retained samples, or `None` when there are none.
    pub fn moving_average(&self) -> Option<f64> {
        if self.history.is_empty() {
            return None;
        }
        let sum: f64 = self.history.iter().sum();
        Some(sum / self.history.len() as f64)
    }

    /// Smallest retained sample (per `f64::min`), or `None` when there are none.
    pub fn min(&self) -> Option<f64> {
        self.history.iter().copied().reduce(f64::min)
    }

    /// Largest retained sample (per `f64::max`), or `None` when there are none.
    pub fn max(&self) -> Option<f64> {
        self.history.iter().copied().reduce(f64::max)
    }

    /// Reports whether the last `patience` samples reached a new minimum.
    ///
    /// Returns `true` when the minimum of the most recent `patience` samples is
    /// strictly below the minimum of all earlier retained samples. Returns
    /// `false` when there are not more than `patience` samples, or when
    /// `patience` is zero (the recent window is then empty).
    pub fn is_improving(&self, patience: usize) -> bool {
        if self.history.len() <= patience {
            return false;
        }
        let split = self.history.len() - patience;
        let older_min = self.history.iter().take(split).copied().reduce(f64::min);
        let recent_min = self.history.iter().skip(split).copied().reduce(f64::min);
        match (older_min, recent_min) {
            (Some(old), Some(new)) => new < old,
            _ => false,
        }
    }

    /// Number of retained samples.
    pub fn len(&self) -> usize {
        self.history.len()
    }

    /// Returns `true` when no samples are retained.
    pub fn is_empty(&self) -> bool {
        self.history.is_empty()
    }
}
691
// Unit tests covering the checkpoint data model, both serialization formats,
// the on-disk manager, the loss tracker, and error formatting.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an "adam" checkpoint with two hyperparameters and one
    /// three-element parameter state named `layer0.weight`.
    fn make_ckpt(step: u64, epoch: u32) -> OptimizerCheckpoint {
        let mut ckpt = OptimizerCheckpoint::new("adam", step, epoch);
        ckpt.set_hyperparam("lr", 0.001);
        ckpt.set_hyperparam("beta1", 0.9);
        let ps = ParamState::new(
            "layer0.weight",
            vec![0.1, 0.2, 0.3],
            vec![0.01, 0.02, 0.03],
            step,
            vec![3],
        );
        ckpt.add_param_state("layer0.weight", ps);
        ckpt
    }

    // --- OptimizerCheckpoint data model ---

    /// The constructor stores its arguments unchanged.
    #[test]
    fn test_optimizer_checkpoint_new() {
        let ckpt = OptimizerCheckpoint::new("sgd", 42, 3);
        assert_eq!(ckpt.step, 42);
        assert_eq!(ckpt.epoch, 3);
        assert_eq!(ckpt.optimizer_name, "sgd");
    }

    /// Adding a parameter state increases the parameter count.
    #[test]
    fn test_add_param_state() {
        let mut ckpt = OptimizerCheckpoint::new("adam", 0, 0);
        assert_eq!(ckpt.num_params(), 0);
        let ps = ParamState::new("w", vec![1.0], vec![], 0, vec![1]);
        ckpt.add_param_state("w", ps);
        assert_eq!(ckpt.num_params(), 1);
    }

    /// A stored hyperparameter can be read back; an unknown key yields `None`.
    #[test]
    fn test_set_get_hyperparam() {
        let mut ckpt = OptimizerCheckpoint::new("adam", 0, 0);
        ckpt.set_hyperparam("lr", 3e-4);
        let retrieved = ckpt.get_hyperparam("lr");
        assert!(retrieved.is_some());
        let diff = (retrieved.unwrap_or(0.0) - 3e-4).abs();
        assert!(diff < 1e-12, "hyperparam roundtrip mismatch");
        assert!(ckpt.get_hyperparam("missing").is_none());
    }

    /// `total_elements` sums the first-moment lengths (2 + 3 here).
    #[test]
    fn test_total_elements() {
        let mut ckpt = OptimizerCheckpoint::new("adam", 0, 0);
        ckpt.add_param_state(
            "a",
            ParamState::new("a", vec![1.0, 2.0], vec![], 0, vec![2]),
        );
        ckpt.add_param_state(
            "b",
            ParamState::new("b", vec![3.0, 4.0, 5.0], vec![], 0, vec![3]),
        );
        assert_eq!(ckpt.total_elements(), 5);
    }

    // --- Serialization round-trips ---

    /// Header fields survive a text round-trip.
    #[test]
    fn test_serialize_text_roundtrip() {
        let ckpt = make_ckpt(100, 2);
        let bytes = serialize_checkpoint(&ckpt, CheckpointFormat::Text).expect("serialize text");
        let loaded =
            deserialize_checkpoint(&bytes, CheckpointFormat::Text).expect("deserialize text");
        assert_eq!(loaded.step, 100);
        assert_eq!(loaded.epoch, 2);
        assert_eq!(loaded.optimizer_name, "adam");
    }

    /// Parameter state (moments and shape) survives a text round-trip.
    #[test]
    fn test_serialize_text_param_states() {
        let ckpt = make_ckpt(50, 1);
        let bytes = serialize_checkpoint(&ckpt, CheckpointFormat::Text).expect("serialize");
        let loaded = deserialize_checkpoint(&bytes, CheckpointFormat::Text).expect("deserialize");
        assert_eq!(loaded.num_params(), 1);
        let ps = loaded
            .param_states
            .get("layer0.weight")
            .expect("param not found");
        assert_eq!(ps.first_moment, vec![0.1, 0.2, 0.3]);
        assert_eq!(ps.second_moment, vec![0.01, 0.02, 0.03]);
        assert_eq!(ps.shape, vec![3]);
    }

    /// Binary output starts with the magic bytes and round-trips header fields.
    #[test]
    fn test_serialize_binary_roundtrip() {
        let ckpt = make_ckpt(200, 5);
        let bytes =
            serialize_checkpoint(&ckpt, CheckpointFormat::Binary).expect("serialize binary");
        assert_eq!(&bytes[..4], &BINARY_MAGIC);
        let loaded =
            deserialize_checkpoint(&bytes, CheckpointFormat::Binary).expect("deserialize binary");
        assert_eq!(loaded.step, 200);
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.optimizer_name, "adam");
    }

    /// Hyperparameters round-trip in both formats.
    #[test]
    fn test_serialize_hyperparams_roundtrip() {
        let mut ckpt = OptimizerCheckpoint::new("rmsprop", 10, 0);
        ckpt.set_hyperparam("alpha", 0.99);
        ckpt.set_hyperparam("eps", 1e-8);

        for format in [CheckpointFormat::Text, CheckpointFormat::Binary] {
            let bytes = serialize_checkpoint(&ckpt, format.clone()).expect("serialize");
            let loaded = deserialize_checkpoint(&bytes, format).expect("deserialize");
            let alpha = loaded.get_hyperparam("alpha").expect("alpha");
            let eps = loaded.get_hyperparam("eps").expect("eps");
            assert!((alpha - 0.99).abs() < 1e-12);
            assert!((eps - 1e-8).abs() < 1e-20);
        }
    }

    // --- CheckpointManager (touches the filesystem) ---

    /// Returns a scratch directory path under the system temp dir, made unique
    /// by `suffix` and the current process id. The directory is not created here.
    fn tmp_dir(suffix: &str) -> PathBuf {
        let mut p = std::env::temp_dir();
        p.push(format!("tl_ckpt_test_{suffix}_{}", std::process::id()));
        p
    }

    /// Constructing a manager creates its directory.
    #[test]
    fn test_checkpoint_manager_new_creates_dir() {
        let dir = tmp_dir("new_creates_dir");
        let _mgr =
            CheckpointManager::new(&dir, 3, CheckpointFormat::Text).expect("manager creation");
        assert!(dir.exists(), "directory should have been created");
        // Best-effort cleanup; a failure here must not fail the test.
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `save` returns the path of a file that exists on disk.
    #[test]
    fn test_checkpoint_manager_save_creates_file() {
        let dir = tmp_dir("save_creates_file");
        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Text).expect("manager");
        let ckpt = make_ckpt(1, 0);
        let path = mgr.save(&ckpt).expect("save");
        assert!(path.exists(), "saved file should exist");
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `load_latest` returns the checkpoint that was just saved.
    #[test]
    fn test_checkpoint_manager_load_latest() {
        let dir = tmp_dir("load_latest");
        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Text).expect("manager");
        let ckpt = make_ckpt(7, 1);
        mgr.save(&ckpt).expect("save");
        let loaded = mgr.load_latest().expect("load_latest");
        assert_eq!(loaded.step, 7);
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `list` reports every saved step.
    #[test]
    fn test_checkpoint_manager_list() {
        let dir = tmp_dir("list");
        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Text).expect("manager");
        mgr.save(&make_ckpt(10, 0)).expect("save 1");
        mgr.save(&make_ckpt(20, 1)).expect("save 2");
        let list = mgr.list();
        assert_eq!(list.len(), 2);
        let steps: Vec<u64> = list.iter().map(|(s, _)| *s).collect();
        assert!(steps.contains(&10));
        assert!(steps.contains(&20));
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// After five saves with `max_to_keep = 3`, only steps 20, 30 and 40 remain.
    #[test]
    fn test_checkpoint_manager_max_to_keep() {
        let dir = tmp_dir("max_to_keep");
        let mut mgr = CheckpointManager::new(&dir, 3, CheckpointFormat::Text).expect("manager");
        for step in 0..5_u64 {
            mgr.save(&make_ckpt(step * 10, step as u32)).expect("save");
        }
        assert_eq!(mgr.count(), 3, "only last 3 should be retained");
        let steps: Vec<u64> = mgr.list().iter().map(|(s, _)| *s).collect();
        assert!(steps.contains(&20));
        assert!(steps.contains(&30));
        assert!(steps.contains(&40));
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `load_at_step` finds an older checkpoint (binary format).
    #[test]
    fn test_checkpoint_manager_load_at_step() {
        let dir = tmp_dir("load_at_step");
        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Binary).expect("manager");
        mgr.save(&make_ckpt(5, 0)).expect("save");
        mgr.save(&make_ckpt(10, 1)).expect("save");
        let loaded = mgr.load_at_step(5).expect("load step 5");
        assert_eq!(loaded.step, 5);
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Loading from a manager with no saves reports `NoCheckpointsAvailable`.
    #[test]
    fn test_checkpoint_manager_no_checkpoints() {
        let dir = tmp_dir("no_checkpoints");
        let mgr = CheckpointManager::new(&dir, 3, CheckpointFormat::Text).expect("manager");
        let result = mgr.load_latest();
        assert!(
            matches!(result, Err(CheckpointError::NoCheckpointsAvailable)),
            "expected NoCheckpointsAvailable, got {result:?}"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    // --- LossTracker ---

    /// The mean of 1, 2, 3 is 2.
    #[test]
    fn test_loss_tracker_moving_average() {
        let mut tracker = LossTracker::new(5);
        tracker.push(1.0);
        tracker.push(2.0);
        tracker.push(3.0);
        let avg = tracker.moving_average().expect("average");
        let diff = (avg - 2.0).abs();
        assert!(diff < 1e-12, "expected 2.0, got {avg}");
    }

    /// `min` and `max` return the extremes of the window.
    #[test]
    fn test_loss_tracker_min_max() {
        let mut tracker = LossTracker::new(10);
        for v in [5.0, 1.0, 8.0, 3.0_f64] {
            tracker.push(v);
        }
        assert!((tracker.min().expect("min") - 1.0).abs() < 1e-12);
        assert!((tracker.max().expect("max") - 8.0).abs() < 1e-12);
    }

    /// The last two samples (min 4.7) beat the earlier ones (min 4.8).
    #[test]
    fn test_loss_tracker_is_improving_true() {
        let mut tracker = LossTracker::new(10);
        for v in [5.0, 4.8, 4.7, 4.9_f64] {
            tracker.push(v);
        }
        assert!(
            tracker.is_improving(2),
            "expected improving with decreasing loss"
        );
    }

    /// The last two samples (min 3.0) do not beat the earlier ones (min 1.0).
    #[test]
    fn test_loss_tracker_is_improving_false() {
        let mut tracker = LossTracker::new(10);
        for v in [1.0, 2.0, 3.0, 4.0_f64] {
            tracker.push(v);
        }
        assert!(
            !tracker.is_improving(2),
            "expected not improving with increasing loss"
        );
    }

    // --- CheckpointError ---

    /// Every error variant renders a non-empty `Display` string.
    #[test]
    fn test_checkpoint_error_display() {
        let variants: Vec<CheckpointError> = vec![
            CheckpointError::IoError("test io".into()),
            CheckpointError::SerializationError("test ser".into()),
            CheckpointError::DeserializationError("test deser".into()),
            CheckpointError::CheckpointNotFound { step: 42 },
            CheckpointError::NoCheckpointsAvailable,
            CheckpointError::InvalidFormat("bad".into()),
            CheckpointError::DirectoryCreationFailed("dir".into()),
        ];
        for err in &variants {
            let s = err.to_string();
            assert!(
                !s.is_empty(),
                "display output should not be empty for {err:?}"
            );
        }
    }
}