Skip to main content

tsai_models/
checkpoint.rs

1//! Model checkpointing and serialization utilities.
2//!
3//! Provides utilities for saving and loading model weights using Burn's record system.
4//!
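//! # Supported Formats
//!
//! - **MessagePack** (`*.mpk`): Fast binary format, good for local storage
//! - **Named MessagePack** (`*.named.mpk`): Includes parameter names
//!
//! Both formats are currently written and read with Burn's `NamedMpkFileRecorder`
//! at full precision. **SafeTensors** (`*.safetensors`, a portable format that is
//! safe for sharing) is not yet selectable through [`CheckpointFormat`].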
10//!
11//! # Example
12//!
//! ```rust,ignore
//! use tsai_models::checkpoint::{load_record, save_model, CheckpointFormat};
//! use tsai_models::InceptionTimePlusConfig;
//!
//! // Configure and create model
//! let config = InceptionTimePlusConfig::new(1, 100, 5);
//! let model = config.init::<NdArray>(&device);
//!
//! // Save model
//! save_model(&model, "model.mpk", CheckpointFormat::MessagePack)?;
//!
//! // Load the saved record and apply it to a freshly initialised model
//! let record = load_record::<NdArray, InceptionTimePlus<NdArray>>(
//!     "model.mpk",
//!     CheckpointFormat::MessagePack,
//!     &device,
//! )?;
//! let loaded = config.init::<NdArray>(&device).load_record(record);
//! ```
32
33use std::path::Path;
34
35use burn::module::Module;
36use burn::prelude::*;
37use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
38use serde::{de::DeserializeOwned, Serialize};
39
/// On-disk encoding selected when saving or loading a checkpoint.
///
/// The value is a plain tag: it is `Copy`, comparable with `==`, and printable
/// with `{:?}`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckpointFormat {
    /// MessagePack binary encoding (fast, compact).
    MessagePack,
    /// MessagePack encoding that also carries parameter names.
    NamedMessagePack,
}
48
/// Numeric precision requested for stored checkpoint weights.
///
/// NOTE(review): nothing in this file reads this setting; `save_model` and
/// `load_record` are hard-wired to `FullPrecisionSettings`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckpointPrecision {
    /// Full precision (f32).
    Full,
    /// Half precision (f16).
    Half,
}
57
58/// Save a model to a checkpoint file.
59///
60/// # Arguments
61///
62/// * `model` - The model to save
63/// * `path` - Output path
64/// * `format` - Checkpoint format
65///
66/// # Returns
67///
68/// Result indicating success or failure.
69pub fn save_model<B, M>(model: &M, path: impl AsRef<Path>, format: CheckpointFormat) -> Result<()>
70where
71    B: Backend,
72    M: Module<B>,
73    M::Record: Serialize,
74{
75    let path = path.as_ref();
76    let record = model.clone().into_record();
77
78    match format {
79        CheckpointFormat::MessagePack | CheckpointFormat::NamedMessagePack => {
80            let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
81            recorder
82                .record(record, path.to_path_buf())
83                .map_err(|e| CheckpointError::Save(e.to_string()))?;
84        }
85    }
86
87    Ok(())
88}
89
90/// Load a model from a checkpoint file.
91///
92/// # Arguments
93///
94/// * `path` - Path to checkpoint
95/// * `format` - Checkpoint format
96/// * `device` - Device to load model onto
97///
98/// # Returns
99///
100/// The loaded model record.
101pub fn load_record<B, M>(
102    path: impl AsRef<Path>,
103    format: CheckpointFormat,
104    device: &B::Device,
105) -> Result<M::Record>
106where
107    B: Backend,
108    M: Module<B>,
109    M::Record: DeserializeOwned,
110{
111    let path = path.as_ref();
112
113    let record = match format {
114        CheckpointFormat::MessagePack | CheckpointFormat::NamedMessagePack => {
115            let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
116            recorder
117                .load(path.to_path_buf(), device)
118                .map_err(|e| CheckpointError::Load(e.to_string()))?
119        }
120    };
121
122    Ok(record)
123}
124
125/// Model checkpoint metadata.
126#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
127pub struct CheckpointMetadata {
128    /// Model architecture name.
129    pub arch: String,
130    /// Model configuration as JSON.
131    pub config_json: String,
132    /// Training epoch (if applicable).
133    pub epoch: Option<usize>,
134    /// Validation loss (if applicable).
135    pub val_loss: Option<f32>,
136    /// Validation accuracy (if applicable).
137    pub val_acc: Option<f32>,
138    /// Additional metadata.
139    pub extra: std::collections::HashMap<String, String>,
140}
141
142impl CheckpointMetadata {
143    /// Create new metadata for a model.
144    pub fn new(arch: impl Into<String>) -> Self {
145        Self {
146            arch: arch.into(),
147            config_json: String::new(),
148            epoch: None,
149            val_loss: None,
150            val_acc: None,
151            extra: std::collections::HashMap::new(),
152        }
153    }
154
155    /// Set the config JSON.
156    #[must_use]
157    pub fn with_config<C: Serialize>(mut self, config: &C) -> Self {
158        self.config_json = serde_json::to_string(config).unwrap_or_default();
159        self
160    }
161
162    /// Set the training epoch.
163    #[must_use]
164    pub fn with_epoch(mut self, epoch: usize) -> Self {
165        self.epoch = Some(epoch);
166        self
167    }
168
169    /// Set the validation loss.
170    #[must_use]
171    pub fn with_val_loss(mut self, loss: f32) -> Self {
172        self.val_loss = Some(loss);
173        self
174    }
175
176    /// Set the validation accuracy.
177    #[must_use]
178    pub fn with_val_acc(mut self, acc: f32) -> Self {
179        self.val_acc = Some(acc);
180        self
181    }
182
183    /// Add extra metadata.
184    #[must_use]
185    pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
186        self.extra.insert(key.into(), value.into());
187        self
188    }
189
190    /// Save metadata to a JSON file.
191    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
192        let json = serde_json::to_string_pretty(self)
193            .map_err(|e| CheckpointError::Save(e.to_string()))?;
194        std::fs::write(path, json).map_err(|e| CheckpointError::Save(e.to_string()))?;
195        Ok(())
196    }
197
198    /// Load metadata from a JSON file.
199    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
200        let json =
201            std::fs::read_to_string(path).map_err(|e| CheckpointError::Load(e.to_string()))?;
202        serde_json::from_str(&json).map_err(|e| CheckpointError::Load(e.to_string()))
203    }
204}
205
206/// Result type for checkpoint operations.
207pub type Result<T> = std::result::Result<T, CheckpointError>;
208
209/// Checkpoint-related errors.
210#[derive(Debug, thiserror::Error)]
211pub enum CheckpointError {
212    /// Error saving checkpoint.
213    #[error("Failed to save checkpoint: {0}")]
214    Save(String),
215
216    /// Error loading checkpoint.
217    #[error("Failed to load checkpoint: {0}")]
218    Load(String),
219
220    /// Invalid format.
221    #[error("Invalid checkpoint format: {0}")]
222    InvalidFormat(String),
223}
224
225/// Extension trait for models to add checkpoint methods.
226pub trait ModelCheckpoint<B: Backend>: Module<B> {
227    /// Save the model to a checkpoint file.
228    fn save_checkpoint(&self, path: impl AsRef<Path>) -> Result<()>
229    where
230        Self::Record: Serialize,
231    {
232        save_model::<B, Self>(self, path, CheckpointFormat::NamedMessagePack)
233    }
234
235    /// Load model from a checkpoint into an existing model.
236    fn load_checkpoint(&self, path: impl AsRef<Path>, device: &B::Device) -> Result<Self>
237    where
238        Self: Sized,
239        Self::Record: DeserializeOwned,
240    {
241        let record = load_record::<B, Self>(path, CheckpointFormat::NamedMessagePack, device)?;
242        Ok(self.clone().load_record(record))
243    }
244}
245
246// Implement for all modules
247impl<B: Backend, M: Module<B>> ModelCheckpoint<B> for M {}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods populate the fields they name.
    #[test]
    fn test_checkpoint_metadata() {
        let built = CheckpointMetadata::new("InceptionTimePlus")
            .with_epoch(10)
            .with_val_loss(0.25)
            .with_val_acc(0.92)
            .with_extra("dataset", "ECG200");

        assert_eq!(built.arch, "InceptionTimePlus");
        assert_eq!(built.epoch, Some(10));
        assert_eq!(built.val_loss, Some(0.25));
        assert_eq!(built.val_acc, Some(0.92));
        assert_eq!(
            built.extra.get("dataset").map(String::as_str),
            Some("ECG200")
        );
    }

    /// Format variants are equal to themselves and distinct from each other.
    #[test]
    fn test_checkpoint_format() {
        let plain = CheckpointFormat::MessagePack;
        let named = CheckpointFormat::NamedMessagePack;

        assert_eq!(plain, CheckpointFormat::MessagePack);
        assert_ne!(plain, named);
    }

    /// Precision variants are equal to themselves and distinct from each other.
    #[test]
    fn test_checkpoint_precision() {
        let full = CheckpointPrecision::Full;
        let half = CheckpointPrecision::Half;

        assert_eq!(full, CheckpointPrecision::Full);
        assert_ne!(full, half);
    }
}