Skip to main content

tensorlogic_train/
checkpoint.rs

1//! Optimizer checkpointing: save/load optimizer state (momentum buffers, step counts, etc.)
2//!
3//! Provides [`OptimizerCheckpoint`], [`CheckpointManager`], and [`LossTracker`] for
4//! persisting and restoring optimizer state during training.
5
6use std::collections::{HashMap, VecDeque};
7use std::path::{Path, PathBuf};
8
9// ---------------------------------------------------------------------------
10// Core data types
11// ---------------------------------------------------------------------------
12
13/// Per-parameter optimizer state (moment vectors, step counter, shape).
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct ParamState {
16    pub name: String,
17    /// First moment estimate: velocity for SGD-momentum, m for Adam.
18    pub first_moment: Vec<f64>,
19    /// Second moment estimate: v for Adam; empty for SGD.
20    pub second_moment: Vec<f64>,
21    pub step: u64,
22    pub shape: Vec<usize>,
23}
24
25impl ParamState {
26    /// Construct a new [`ParamState`].
27    pub fn new(
28        name: impl Into<String>,
29        first_moment: Vec<f64>,
30        second_moment: Vec<f64>,
31        step: u64,
32        shape: Vec<usize>,
33    ) -> Self {
34        Self {
35            name: name.into(),
36            first_moment,
37            second_moment,
38            step,
39            shape,
40        }
41    }
42}
43
44/// Metadata attached to a checkpoint (loss values, extra annotations).
45#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
46pub struct CheckpointMetadata {
47    pub created_at_step: u64,
48    pub loss: Option<f64>,
49    pub val_loss: Option<f64>,
50    pub extra: HashMap<String, String>,
51}
52
53impl CheckpointMetadata {
54    /// Construct metadata for a given step.
55    pub fn new(created_at_step: u64) -> Self {
56        Self {
57            created_at_step,
58            loss: None,
59            val_loss: None,
60            extra: HashMap::new(),
61        }
62    }
63}
64
65/// Serialisable snapshot of an optimizer's full training state.
66#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
67pub struct OptimizerCheckpoint {
68    pub step: u64,
69    pub epoch: u32,
70    pub optimizer_name: String,
71    /// Parameter name → per-parameter state.
72    pub param_states: HashMap<String, ParamState>,
73    pub hyperparams: HashMap<String, f64>,
74    pub metadata: CheckpointMetadata,
75}
76
77impl OptimizerCheckpoint {
78    /// Create an empty checkpoint for the given optimizer at `step`/`epoch`.
79    pub fn new(optimizer_name: impl Into<String>, step: u64, epoch: u32) -> Self {
80        Self {
81            step,
82            epoch,
83            optimizer_name: optimizer_name.into(),
84            param_states: HashMap::new(),
85            hyperparams: HashMap::new(),
86            metadata: CheckpointMetadata::new(step),
87        }
88    }
89
90    /// Insert or replace the state for one parameter.
91    pub fn add_param_state(&mut self, name: impl Into<String>, state: ParamState) {
92        self.param_states.insert(name.into(), state);
93    }
94
95    /// Record a scalar hyper-parameter (learning rate, beta1, etc.).
96    pub fn set_hyperparam(&mut self, key: impl Into<String>, value: f64) {
97        self.hyperparams.insert(key.into(), value);
98    }
99
100    /// Retrieve a previously recorded hyper-parameter value.
101    pub fn get_hyperparam(&self, key: &str) -> Option<f64> {
102        self.hyperparams.get(key).copied()
103    }
104
105    /// Number of parameters stored in this checkpoint.
106    pub fn num_params(&self) -> usize {
107        self.param_states.len()
108    }
109
110    /// Total number of scalar elements across all first-moment vectors.
111    pub fn total_elements(&self) -> usize {
112        self.param_states
113            .values()
114            .map(|ps| ps.first_moment.len())
115            .sum()
116    }
117}
118
// ---------------------------------------------------------------------------
// Serialization format
// ---------------------------------------------------------------------------

/// On-disk encoding used when writing and reading checkpoints.
#[derive(Debug, Clone)]
pub enum CheckpointFormat {
    /// Custom binary envelope: magic bytes `TLCK`, a big-endian `u32` version,
    /// a big-endian `u32` payload length, then a JSON payload.
    Binary,
    /// Human-readable `key=value` text with `\n---\n` section separators.
    Text,
}

impl CheckpointFormat {
    /// File-name extension (without the leading dot) for this format.
    fn file_extension(&self) -> &'static str {
        match self {
            Self::Text => "tlckt",
            Self::Binary => "tlck",
        }
    }
}

// ---------------------------------------------------------------------------
// CheckpointError
// ---------------------------------------------------------------------------

/// Failure modes of checkpoint (de)serialization and checkpoint-manager operations.
#[derive(Debug, Clone)]
pub enum CheckpointError {
    /// A filesystem read, write or delete failed; carries a description.
    IoError(String),
    /// Encoding a checkpoint failed.
    SerializationError(String),
    /// Decoding checkpoint bytes failed.
    DeserializationError(String),
    /// No retained checkpoint exists for the requested step.
    CheckpointNotFound { step: u64 },
    /// The latest checkpoint was requested but none has been saved.
    NoCheckpointsAvailable,
    /// The bytes do not have the expected envelope (magic, version, length).
    InvalidFormat(String),
    /// The checkpoint directory could not be created.
    DirectoryCreationFailed(String),
}

impl std::fmt::Display for CheckpointError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IoError(detail) => write!(f, "IO error: {detail}"),
            Self::SerializationError(detail) => write!(f, "Serialization error: {detail}"),
            Self::DeserializationError(detail) => write!(f, "Deserialization error: {detail}"),
            Self::CheckpointNotFound { step } => {
                write!(f, "Checkpoint not found for step {step}")
            }
            Self::NoCheckpointsAvailable => f.write_str("No checkpoints are available"),
            Self::InvalidFormat(detail) => write!(f, "Invalid format: {detail}"),
            Self::DirectoryCreationFailed(detail) => {
                write!(f, "Directory creation failed: {detail}")
            }
        }
    }
}

impl std::error::Error for CheckpointError {}

// ---------------------------------------------------------------------------
// Serialization helpers (Text format)
// ---------------------------------------------------------------------------

/// Render `values` as a comma-separated list using each value's `Display` form.
///
/// An empty slice yields an empty string.
fn encode_f64_slice(values: &[f64]) -> String {
    let mut encoded = String::new();
    for (idx, value) in values.iter().enumerate() {
        if idx > 0 {
            encoded.push(',');
        }
        encoded.push_str(&value.to_string());
    }
    encoded
}

196/// Decode a comma-separated string into `Vec<f64>`.
197fn decode_f64_slice(s: &str) -> Result<Vec<f64>, CheckpointError> {
198    if s.is_empty() {
199        return Ok(Vec::new());
200    }
201    s.split(',')
202        .map(|tok| {
203            tok.trim()
204                .parse::<f64>()
205                .map_err(|e| CheckpointError::DeserializationError(format!("f64 parse: {e}")))
206        })
207        .collect()
208}
209
/// Render `values` as a comma-separated list of decimal integers.
///
/// An empty slice yields an empty string.
fn encode_usize_slice(values: &[usize]) -> String {
    let parts: Vec<String> = values.iter().map(ToString::to_string).collect();
    parts.join(",")
}

219/// Decode a comma-separated string into `Vec<usize>`.
220fn decode_usize_slice(s: &str) -> Result<Vec<usize>, CheckpointError> {
221    if s.is_empty() {
222        return Ok(Vec::new());
223    }
224    s.split(',')
225        .map(|tok| {
226            tok.trim()
227                .parse::<usize>()
228                .map_err(|e| CheckpointError::DeserializationError(format!("usize parse: {e}")))
229        })
230        .collect()
231}
232
233/// Serialize [`OptimizerCheckpoint`] to the `Text` format.
234fn serialize_text(ckpt: &OptimizerCheckpoint) -> Vec<u8> {
235    let mut out = String::new();
236
237    // --- header section ---
238    out.push_str("section=header\n");
239    out.push_str(&format!("step={}\n", ckpt.step));
240    out.push_str(&format!("epoch={}\n", ckpt.epoch));
241    out.push_str(&format!("optimizer_name={}\n", ckpt.optimizer_name));
242    out.push_str(&format!(
243        "created_at_step={}\n",
244        ckpt.metadata.created_at_step
245    ));
246    if let Some(loss) = ckpt.metadata.loss {
247        out.push_str(&format!("loss={loss}\n"));
248    }
249    if let Some(val_loss) = ckpt.metadata.val_loss {
250        out.push_str(&format!("val_loss={val_loss}\n"));
251    }
252    for (k, v) in &ckpt.metadata.extra {
253        out.push_str(&format!("extra.{k}={v}\n"));
254    }
255
256    out.push_str("\n---\n");
257
258    // --- hyperparams section ---
259    out.push_str("section=hyperparams\n");
260    for (k, v) in &ckpt.hyperparams {
261        out.push_str(&format!("hp.{k}={v}\n"));
262    }
263
264    out.push_str("\n---\n");
265
266    // --- param_states section ---
267    out.push_str("section=param_states\n");
268    for (param_name, ps) in &ckpt.param_states {
269        out.push_str(&format!("param.name={param_name}\n"));
270        out.push_str(&format!(
271            "param.first_moment={}\n",
272            encode_f64_slice(&ps.first_moment)
273        ));
274        out.push_str(&format!(
275            "param.second_moment={}\n",
276            encode_f64_slice(&ps.second_moment)
277        ));
278        out.push_str(&format!("param.step={}\n", ps.step));
279        out.push_str(&format!("param.shape={}\n", encode_usize_slice(&ps.shape)));
280        out.push_str("param.end\n");
281    }
282
283    out.into_bytes()
284}
285
286/// Deserialize [`OptimizerCheckpoint`] from the `Text` format.
287fn deserialize_text(bytes: &[u8]) -> Result<OptimizerCheckpoint, CheckpointError> {
288    let text = std::str::from_utf8(bytes)
289        .map_err(|e| CheckpointError::DeserializationError(format!("UTF-8: {e}")))?;
290
291    let mut step: Option<u64> = None;
292    let mut epoch: Option<u32> = None;
293    let mut optimizer_name: Option<String> = None;
294    let mut created_at_step: u64 = 0;
295    let mut loss: Option<f64> = None;
296    let mut val_loss: Option<f64> = None;
297    let mut extra: HashMap<String, String> = HashMap::new();
298    let mut hyperparams: HashMap<String, f64> = HashMap::new();
299    let mut param_states: HashMap<String, ParamState> = HashMap::new();
300
301    // Working state for the currently-open param block.
302    let mut cur_name: Option<String> = None;
303    let mut cur_first: Vec<f64> = Vec::new();
304    let mut cur_second: Vec<f64> = Vec::new();
305    let mut cur_step: u64 = 0;
306    let mut cur_shape: Vec<usize> = Vec::new();
307
308    for raw_line in text.lines() {
309        let line = raw_line.trim();
310        if line.is_empty() || line == "---" {
311            continue;
312        }
313        if line.starts_with("section=") {
314            continue;
315        }
316        if line == "param.end" {
317            if let Some(name) = cur_name.take() {
318                param_states.insert(
319                    name.clone(),
320                    ParamState {
321                        name,
322                        first_moment: std::mem::take(&mut cur_first),
323                        second_moment: std::mem::take(&mut cur_second),
324                        step: cur_step,
325                        shape: std::mem::take(&mut cur_shape),
326                    },
327                );
328            }
329            cur_step = 0;
330            continue;
331        }
332
333        let (key, value) = line.split_once('=').ok_or_else(|| {
334            CheckpointError::DeserializationError(format!("Missing '=' in line: {line}"))
335        })?;
336
337        match key {
338            "step" => {
339                step =
340                    Some(value.parse::<u64>().map_err(|e| {
341                        CheckpointError::DeserializationError(format!("step: {e}"))
342                    })?);
343            }
344            "epoch" => {
345                epoch =
346                    Some(value.parse::<u32>().map_err(|e| {
347                        CheckpointError::DeserializationError(format!("epoch: {e}"))
348                    })?);
349            }
350            "optimizer_name" => {
351                optimizer_name = Some(value.to_owned());
352            }
353            "created_at_step" => {
354                created_at_step = value.parse::<u64>().map_err(|e| {
355                    CheckpointError::DeserializationError(format!("created_at_step: {e}"))
356                })?;
357            }
358            "loss" => {
359                loss =
360                    Some(value.parse::<f64>().map_err(|e| {
361                        CheckpointError::DeserializationError(format!("loss: {e}"))
362                    })?);
363            }
364            "val_loss" => {
365                val_loss = Some(value.parse::<f64>().map_err(|e| {
366                    CheckpointError::DeserializationError(format!("val_loss: {e}"))
367                })?);
368            }
369            "param.name" => {
370                cur_name = Some(value.to_owned());
371            }
372            "param.first_moment" => {
373                cur_first = decode_f64_slice(value)?;
374            }
375            "param.second_moment" => {
376                cur_second = decode_f64_slice(value)?;
377            }
378            "param.step" => {
379                cur_step = value.parse::<u64>().map_err(|e| {
380                    CheckpointError::DeserializationError(format!("param.step: {e}"))
381                })?;
382            }
383            "param.shape" => {
384                cur_shape = decode_usize_slice(value)?;
385            }
386            other if other.starts_with("hp.") => {
387                let hp_key = other.trim_start_matches("hp.");
388                let hp_val = value.parse::<f64>().map_err(|e| {
389                    CheckpointError::DeserializationError(format!("hyperparam {hp_key}: {e}"))
390                })?;
391                hyperparams.insert(hp_key.to_owned(), hp_val);
392            }
393            other if other.starts_with("extra.") => {
394                let ex_key = other.trim_start_matches("extra.");
395                extra.insert(ex_key.to_owned(), value.to_owned());
396            }
397            _ => {} // tolerate unknown fields for forward-compat
398        }
399    }
400
401    let step =
402        step.ok_or_else(|| CheckpointError::DeserializationError("missing field: step".into()))?;
403    let epoch = epoch
404        .ok_or_else(|| CheckpointError::DeserializationError("missing field: epoch".into()))?;
405    let optimizer_name = optimizer_name.ok_or_else(|| {
406        CheckpointError::DeserializationError("missing field: optimizer_name".into())
407    })?;
408
409    Ok(OptimizerCheckpoint {
410        step,
411        epoch,
412        optimizer_name,
413        param_states,
414        hyperparams,
415        metadata: CheckpointMetadata {
416            created_at_step,
417            loss,
418            val_loss,
419            extra,
420        },
421    })
422}
423
/// Magic bytes (ASCII `"TLCK"`) that begin every `Binary` checkpoint file.
const BINARY_MAGIC: [u8; 4] = *b"TLCK";
/// Envelope version written by, and required when reading, `Binary` checkpoints.
const BINARY_VERSION: u32 = 1;

428/// Serialize an [`OptimizerCheckpoint`] to bytes using the chosen format.
429pub fn serialize_checkpoint(
430    ckpt: &OptimizerCheckpoint,
431    format: CheckpointFormat,
432) -> Result<Vec<u8>, CheckpointError> {
433    match format {
434        CheckpointFormat::Text => Ok(serialize_text(ckpt)),
435        CheckpointFormat::Binary => {
436            // Payload: serde_json → UTF-8 bytes.
437            let json = serde_json::to_vec(ckpt)
438                .map_err(|e| CheckpointError::SerializationError(format!("JSON: {e}")))?;
439
440            // Envelope: magic(4) + version(4, BE) + payload_len(4, BE) + payload.
441            let payload_len = json.len() as u32;
442            let mut out = Vec::with_capacity(12 + json.len());
443            out.extend_from_slice(&BINARY_MAGIC);
444            out.extend_from_slice(&BINARY_VERSION.to_be_bytes());
445            out.extend_from_slice(&payload_len.to_be_bytes());
446            out.extend_from_slice(&json);
447            Ok(out)
448        }
449    }
450}
451
452/// Deserialize an [`OptimizerCheckpoint`] from bytes using the chosen format.
453pub fn deserialize_checkpoint(
454    bytes: &[u8],
455    format: CheckpointFormat,
456) -> Result<OptimizerCheckpoint, CheckpointError> {
457    match format {
458        CheckpointFormat::Text => deserialize_text(bytes),
459        CheckpointFormat::Binary => {
460            // Validate magic bytes.
461            if bytes.len() < 12 {
462                return Err(CheckpointError::InvalidFormat(
463                    "binary checkpoint too short".into(),
464                ));
465            }
466            if bytes[..4] != BINARY_MAGIC {
467                return Err(CheckpointError::InvalidFormat(
468                    "bad magic bytes — not a TLCK checkpoint".into(),
469                ));
470            }
471            let version = u32::from_be_bytes(
472                bytes[4..8]
473                    .try_into()
474                    .map_err(|_| CheckpointError::InvalidFormat("version bytes".into()))?,
475            );
476            if version != BINARY_VERSION {
477                return Err(CheckpointError::InvalidFormat(format!(
478                    "unsupported version {version}"
479                )));
480            }
481            let payload_len = u32::from_be_bytes(
482                bytes[8..12]
483                    .try_into()
484                    .map_err(|_| CheckpointError::InvalidFormat("length bytes".into()))?,
485            ) as usize;
486            let payload_end = 12 + payload_len;
487            if bytes.len() < payload_end {
488                return Err(CheckpointError::InvalidFormat(
489                    "truncated binary checkpoint".into(),
490                ));
491            }
492            let json = &bytes[12..payload_end];
493            serde_json::from_slice(json)
494                .map_err(|e| CheckpointError::DeserializationError(format!("JSON: {e}")))
495        }
496    }
497}
498
499// ---------------------------------------------------------------------------
500// CheckpointManager
501// ---------------------------------------------------------------------------
502
503/// Manages writing and reading checkpoints under a directory.
504///
505/// Keeps a rolling window of the most-recent `max_to_keep` checkpoints, deleting
506/// older files automatically after each save.
507pub struct CheckpointManager {
508    pub dir: PathBuf,
509    pub max_to_keep: usize,
510    pub format: CheckpointFormat,
511    /// Ordered list of saved checkpoint paths (oldest first).
512    saved: Vec<PathBuf>,
513}
514
515impl CheckpointManager {
516    /// Create a new manager, creating `dir` if it does not already exist.
517    pub fn new(
518        dir: impl AsRef<Path>,
519        max_to_keep: usize,
520        format: CheckpointFormat,
521    ) -> Result<Self, CheckpointError> {
522        let dir = dir.as_ref().to_path_buf();
523        std::fs::create_dir_all(&dir).map_err(|e| {
524            CheckpointError::DirectoryCreationFailed(format!("{}: {e}", dir.display()))
525        })?;
526        Ok(Self {
527            dir,
528            max_to_keep,
529            format,
530            saved: Vec::new(),
531        })
532    }
533
534    /// Compute the filename for a checkpoint at `step` with the given format.
535    fn checkpoint_filename(step: u64, format: &CheckpointFormat) -> String {
536        format!("ckpt-step-{:012}.{}", step, format.file_extension())
537    }
538
539    /// Save `ckpt` to disk and prune old checkpoints. Returns the saved path.
540    pub fn save(&mut self, ckpt: &OptimizerCheckpoint) -> Result<PathBuf, CheckpointError> {
541        let filename = Self::checkpoint_filename(ckpt.step, &self.format);
542        let path = self.dir.join(&filename);
543
544        let bytes = serialize_checkpoint(ckpt, self.format.clone())?;
545        std::fs::write(&path, &bytes)
546            .map_err(|e| CheckpointError::IoError(format!("write {}: {e}", path.display())))?;
547
548        self.saved.push(path.clone());
549        self.prune_old()?;
550        Ok(path)
551    }
552
553    /// Load the most recently saved checkpoint.
554    pub fn load_latest(&self) -> Result<OptimizerCheckpoint, CheckpointError> {
555        let path = self
556            .saved
557            .last()
558            .ok_or(CheckpointError::NoCheckpointsAvailable)?;
559        self.load_from_path(path)
560    }
561
562    /// Load the checkpoint saved at a specific training step.
563    pub fn load_at_step(&self, step: u64) -> Result<OptimizerCheckpoint, CheckpointError> {
564        let filename = Self::checkpoint_filename(step, &self.format);
565        let path = self.dir.join(&filename);
566        if !self.saved.iter().any(|p| p == &path) {
567            return Err(CheckpointError::CheckpointNotFound { step });
568        }
569        self.load_from_path(&path)
570    }
571
572    /// Return `(step, path)` pairs for all retained checkpoints.
573    pub fn list(&self) -> Vec<(u64, &Path)> {
574        self.saved
575            .iter()
576            .filter_map(|p| {
577                // Extract step from filename "ckpt-step-<12digits>.<ext>"
578                let stem = p.file_stem()?.to_str()?;
579                let step_str = stem.strip_prefix("ckpt-step-")?;
580                let step = step_str.parse::<u64>().ok()?;
581                Some((step, p.as_path()))
582            })
583            .collect()
584    }
585
586    /// Number of checkpoints currently retained.
587    pub fn count(&self) -> usize {
588        self.saved.len()
589    }
590
591    // --- private helpers ---
592
593    fn load_from_path(&self, path: &Path) -> Result<OptimizerCheckpoint, CheckpointError> {
594        let bytes = std::fs::read(path)
595            .map_err(|e| CheckpointError::IoError(format!("read {}: {e}", path.display())))?;
596        deserialize_checkpoint(&bytes, self.format.clone())
597    }
598
599    /// Delete checkpoints that exceed the `max_to_keep` rolling window.
600    fn prune_old(&mut self) -> Result<(), CheckpointError> {
601        while self.saved.len() > self.max_to_keep {
602            let oldest = self.saved.remove(0);
603            if oldest.exists() {
604                std::fs::remove_file(&oldest).map_err(|e| {
605                    CheckpointError::IoError(format!("delete {}: {e}", oldest.display()))
606                })?;
607            }
608        }
609        Ok(())
610    }
611}
612
// ---------------------------------------------------------------------------
// LossTracker
// ---------------------------------------------------------------------------

/// Rolling-window tracker for scalar loss values recorded during training.
///
/// Provides moving average, min/max, and a simple improvement check useful for
/// early stopping decisions. At most `window_size` of the most recent values
/// are retained.
#[derive(Debug, Clone)]
pub struct LossTracker {
    /// Maximum number of retained values. It may be changed at any time; the
    /// next [`LossTracker::push`] trims the history to fit.
    pub window_size: usize,
    history: VecDeque<f64>,
}

impl LossTracker {
    /// Create a new tracker with the given sliding window capacity.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            history: VecDeque::with_capacity(window_size),
        }
    }

    /// Record a new loss value, evicting the oldest values so that at most
    /// `window_size` values remain.
    ///
    /// Eviction loops on `>=` rather than checking `== window_size` once,
    /// because `window_size` is a public field. If a caller shrinks it, the
    /// history is trimmed to the new size instead of growing without bound.
    /// A `window_size` of 0 retains nothing.
    pub fn push(&mut self, loss: f64) {
        if self.window_size == 0 {
            self.history.clear();
            return;
        }
        while self.history.len() >= self.window_size {
            self.history.pop_front();
        }
        self.history.push_back(loss);
    }

    /// Arithmetic mean over the current window; `None` if empty.
    pub fn moving_average(&self) -> Option<f64> {
        if self.history.is_empty() {
            return None;
        }
        let sum: f64 = self.history.iter().sum();
        Some(sum / self.history.len() as f64)
    }

    /// Minimum value in the current window; `None` if empty.
    ///
    /// Uses `f64::min`, so NaN entries are ignored whenever a non-NaN value exists.
    pub fn min(&self) -> Option<f64> {
        self.history.iter().copied().reduce(f64::min)
    }

    /// Maximum value in the current window; `None` if empty.
    ///
    /// Uses `f64::max`, so NaN entries are ignored whenever a non-NaN value exists.
    pub fn max(&self) -> Option<f64> {
        self.history.iter().copied().reduce(f64::max)
    }

    /// Returns `true` when the minimum of the most recent `patience` values is
    /// strictly less than the minimum of all older values in the window.
    /// This captures "the model has improved recently".
    ///
    /// Returns `false` when there are not enough data points to compare, i.e.
    /// when the window holds `patience` values or fewer.
    pub fn is_improving(&self, patience: usize) -> bool {
        if self.history.len() <= patience {
            return false;
        }
        let split = self.history.len() - patience;
        let older_min = self.history.iter().take(split).copied().reduce(f64::min);
        let recent_min = self.history.iter().skip(split).copied().reduce(f64::min);
        match (older_min, recent_min) {
            (Some(old), Some(new)) => new < old,
            _ => false,
        }
    }

    /// Number of values currently held.
    pub fn len(&self) -> usize {
        self.history.len()
    }

    /// `true` when no values are currently held.
    pub fn is_empty(&self) -> bool {
        self.history.is_empty()
    }
}

692// ---------------------------------------------------------------------------
693// Tests
694// ---------------------------------------------------------------------------
695
696#[cfg(test)]
697mod tests {
698    use super::*;
699
700    // --- helper: build a small checkpoint ---
701    fn make_ckpt(step: u64, epoch: u32) -> OptimizerCheckpoint {
702        let mut ckpt = OptimizerCheckpoint::new("adam", step, epoch);
703        ckpt.set_hyperparam("lr", 0.001);
704        ckpt.set_hyperparam("beta1", 0.9);
705        let ps = ParamState::new(
706            "layer0.weight",
707            vec![0.1, 0.2, 0.3],
708            vec![0.01, 0.02, 0.03],
709            step,
710            vec![3],
711        );
712        ckpt.add_param_state("layer0.weight", ps);
713        ckpt
714    }
715
716    // --- OptimizerCheckpoint ---
717
718    #[test]
719    fn test_optimizer_checkpoint_new() {
720        let ckpt = OptimizerCheckpoint::new("sgd", 42, 3);
721        assert_eq!(ckpt.step, 42);
722        assert_eq!(ckpt.epoch, 3);
723        assert_eq!(ckpt.optimizer_name, "sgd");
724    }
725
726    #[test]
727    fn test_add_param_state() {
728        let mut ckpt = OptimizerCheckpoint::new("adam", 0, 0);
729        assert_eq!(ckpt.num_params(), 0);
730        let ps = ParamState::new("w", vec![1.0], vec![], 0, vec![1]);
731        ckpt.add_param_state("w", ps);
732        assert_eq!(ckpt.num_params(), 1);
733    }
734
735    #[test]
736    fn test_set_get_hyperparam() {
737        let mut ckpt = OptimizerCheckpoint::new("adam", 0, 0);
738        ckpt.set_hyperparam("lr", 3e-4);
739        let retrieved = ckpt.get_hyperparam("lr");
740        assert!(retrieved.is_some());
741        let diff = (retrieved.unwrap_or(0.0) - 3e-4).abs();
742        assert!(diff < 1e-12, "hyperparam roundtrip mismatch");
743        assert!(ckpt.get_hyperparam("missing").is_none());
744    }
745
746    #[test]
747    fn test_total_elements() {
748        let mut ckpt = OptimizerCheckpoint::new("adam", 0, 0);
749        ckpt.add_param_state(
750            "a",
751            ParamState::new("a", vec![1.0, 2.0], vec![], 0, vec![2]),
752        );
753        ckpt.add_param_state(
754            "b",
755            ParamState::new("b", vec![3.0, 4.0, 5.0], vec![], 0, vec![3]),
756        );
757        assert_eq!(ckpt.total_elements(), 5);
758    }
759
760    // --- Text format serialization ---
761
762    #[test]
763    fn test_serialize_text_roundtrip() {
764        let ckpt = make_ckpt(100, 2);
765        let bytes = serialize_checkpoint(&ckpt, CheckpointFormat::Text).expect("serialize text");
766        let loaded =
767            deserialize_checkpoint(&bytes, CheckpointFormat::Text).expect("deserialize text");
768        assert_eq!(loaded.step, 100);
769        assert_eq!(loaded.epoch, 2);
770        assert_eq!(loaded.optimizer_name, "adam");
771    }
772
773    #[test]
774    fn test_serialize_text_param_states() {
775        let ckpt = make_ckpt(50, 1);
776        let bytes = serialize_checkpoint(&ckpt, CheckpointFormat::Text).expect("serialize");
777        let loaded = deserialize_checkpoint(&bytes, CheckpointFormat::Text).expect("deserialize");
778        assert_eq!(loaded.num_params(), 1);
779        let ps = loaded
780            .param_states
781            .get("layer0.weight")
782            .expect("param not found");
783        assert_eq!(ps.first_moment, vec![0.1, 0.2, 0.3]);
784        assert_eq!(ps.second_moment, vec![0.01, 0.02, 0.03]);
785        assert_eq!(ps.shape, vec![3]);
786    }
787
788    // --- Binary format serialization ---
789
790    #[test]
791    fn test_serialize_binary_roundtrip() {
792        let ckpt = make_ckpt(200, 5);
793        let bytes =
794            serialize_checkpoint(&ckpt, CheckpointFormat::Binary).expect("serialize binary");
795        // Verify magic header.
796        assert_eq!(&bytes[..4], &BINARY_MAGIC);
797        let loaded =
798            deserialize_checkpoint(&bytes, CheckpointFormat::Binary).expect("deserialize binary");
799        assert_eq!(loaded.step, 200);
800        assert_eq!(loaded.epoch, 5);
801        assert_eq!(loaded.optimizer_name, "adam");
802    }
803
804    #[test]
805    fn test_serialize_hyperparams_roundtrip() {
806        let mut ckpt = OptimizerCheckpoint::new("rmsprop", 10, 0);
807        ckpt.set_hyperparam("alpha", 0.99);
808        ckpt.set_hyperparam("eps", 1e-8);
809
810        for format in [CheckpointFormat::Text, CheckpointFormat::Binary] {
811            let bytes = serialize_checkpoint(&ckpt, format.clone()).expect("serialize");
812            let loaded = deserialize_checkpoint(&bytes, format).expect("deserialize");
813            let alpha = loaded.get_hyperparam("alpha").expect("alpha");
814            let eps = loaded.get_hyperparam("eps").expect("eps");
815            assert!((alpha - 0.99).abs() < 1e-12);
816            assert!((eps - 1e-8).abs() < 1e-20);
817        }
818    }
819
820    // --- CheckpointManager ---
821
822    fn tmp_dir(suffix: &str) -> PathBuf {
823        let mut p = std::env::temp_dir();
824        p.push(format!("tl_ckpt_test_{suffix}_{}", std::process::id()));
825        p
826    }
827
828    #[test]
829    fn test_checkpoint_manager_new_creates_dir() {
830        let dir = tmp_dir("new_creates_dir");
831        let _mgr =
832            CheckpointManager::new(&dir, 3, CheckpointFormat::Text).expect("manager creation");
833        assert!(dir.exists(), "directory should have been created");
834        let _ = std::fs::remove_dir_all(&dir);
835    }
836
837    #[test]
838    fn test_checkpoint_manager_save_creates_file() {
839        let dir = tmp_dir("save_creates_file");
840        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Text).expect("manager");
841        let ckpt = make_ckpt(1, 0);
842        let path = mgr.save(&ckpt).expect("save");
843        assert!(path.exists(), "saved file should exist");
844        let _ = std::fs::remove_dir_all(&dir);
845    }
846
847    #[test]
848    fn test_checkpoint_manager_load_latest() {
849        let dir = tmp_dir("load_latest");
850        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Text).expect("manager");
851        let ckpt = make_ckpt(7, 1);
852        mgr.save(&ckpt).expect("save");
853        let loaded = mgr.load_latest().expect("load_latest");
854        assert_eq!(loaded.step, 7);
855        let _ = std::fs::remove_dir_all(&dir);
856    }
857
858    #[test]
859    fn test_checkpoint_manager_list() {
860        let dir = tmp_dir("list");
861        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Text).expect("manager");
862        mgr.save(&make_ckpt(10, 0)).expect("save 1");
863        mgr.save(&make_ckpt(20, 1)).expect("save 2");
864        let list = mgr.list();
865        assert_eq!(list.len(), 2);
866        let steps: Vec<u64> = list.iter().map(|(s, _)| *s).collect();
867        assert!(steps.contains(&10));
868        assert!(steps.contains(&20));
869        let _ = std::fs::remove_dir_all(&dir);
870    }
871
872    #[test]
873    fn test_checkpoint_manager_max_to_keep() {
874        let dir = tmp_dir("max_to_keep");
875        let mut mgr = CheckpointManager::new(&dir, 3, CheckpointFormat::Text).expect("manager");
876        for step in 0..5_u64 {
877            mgr.save(&make_ckpt(step * 10, step as u32)).expect("save");
878        }
879        assert_eq!(mgr.count(), 3, "only last 3 should be retained");
880        let steps: Vec<u64> = mgr.list().iter().map(|(s, _)| *s).collect();
881        assert!(steps.contains(&20));
882        assert!(steps.contains(&30));
883        assert!(steps.contains(&40));
884        let _ = std::fs::remove_dir_all(&dir);
885    }
886
887    #[test]
888    fn test_checkpoint_manager_load_at_step() {
889        let dir = tmp_dir("load_at_step");
890        let mut mgr = CheckpointManager::new(&dir, 5, CheckpointFormat::Binary).expect("manager");
891        mgr.save(&make_ckpt(5, 0)).expect("save");
892        mgr.save(&make_ckpt(10, 1)).expect("save");
893        let loaded = mgr.load_at_step(5).expect("load step 5");
894        assert_eq!(loaded.step, 5);
895        let _ = std::fs::remove_dir_all(&dir);
896    }
897
898    #[test]
899    fn test_checkpoint_manager_no_checkpoints() {
900        let dir = tmp_dir("no_checkpoints");
901        let mgr = CheckpointManager::new(&dir, 3, CheckpointFormat::Text).expect("manager");
902        let result = mgr.load_latest();
903        assert!(
904            matches!(result, Err(CheckpointError::NoCheckpointsAvailable)),
905            "expected NoCheckpointsAvailable, got {result:?}"
906        );
907        let _ = std::fs::remove_dir_all(&dir);
908    }
909
910    // --- LossTracker ---
911
912    #[test]
913    fn test_loss_tracker_moving_average() {
914        let mut tracker = LossTracker::new(5);
915        tracker.push(1.0);
916        tracker.push(2.0);
917        tracker.push(3.0);
918        let avg = tracker.moving_average().expect("average");
919        let diff = (avg - 2.0).abs();
920        assert!(diff < 1e-12, "expected 2.0, got {avg}");
921    }
922
923    #[test]
924    fn test_loss_tracker_min_max() {
925        let mut tracker = LossTracker::new(10);
926        for v in [5.0, 1.0, 8.0, 3.0_f64] {
927            tracker.push(v);
928        }
929        assert!((tracker.min().expect("min") - 1.0).abs() < 1e-12);
930        assert!((tracker.max().expect("max") - 8.0).abs() < 1e-12);
931    }
932
933    #[test]
934    fn test_loss_tracker_is_improving_true() {
935        let mut tracker = LossTracker::new(10);
936        // Older values are high; recent values are lower → improving.
937        for v in [5.0, 4.8, 4.7, 4.9_f64] {
938            tracker.push(v);
939        }
940        // patience=2 → recent = [4.7, 4.9], older = [5.0, 4.8].
941        // recent min = 4.7 < older min = 4.8 → true.
942        assert!(
943            tracker.is_improving(2),
944            "expected improving with decreasing loss"
945        );
946    }
947
948    #[test]
949    fn test_loss_tracker_is_improving_false() {
950        let mut tracker = LossTracker::new(10);
951        // Loss is not decreasing.
952        for v in [1.0, 2.0, 3.0, 4.0_f64] {
953            tracker.push(v);
954        }
955        // patience=2 → recent = [3.0, 4.0], older = [1.0, 2.0].
956        // recent min = 3.0 > older min = 1.0 → false.
957        assert!(
958            !tracker.is_improving(2),
959            "expected not improving with increasing loss"
960        );
961    }
962
963    // --- CheckpointError Display ---
964
965    #[test]
966    fn test_checkpoint_error_display() {
967        let variants: Vec<CheckpointError> = vec![
968            CheckpointError::IoError("test io".into()),
969            CheckpointError::SerializationError("test ser".into()),
970            CheckpointError::DeserializationError("test deser".into()),
971            CheckpointError::CheckpointNotFound { step: 42 },
972            CheckpointError::NoCheckpointsAvailable,
973            CheckpointError::InvalidFormat("bad".into()),
974            CheckpointError::DirectoryCreationFailed("dir".into()),
975        ];
976        for err in &variants {
977            let s = err.to_string();
978            assert!(
979                !s.is_empty(),
980                "display output should not be empty for {err:?}"
981            );
982        }
983    }
984}