tsai_models/
checkpoint.rs1use std::path::Path;
34
35use burn::module::Module;
36use burn::prelude::*;
37use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
38use serde::{de::DeserializeOwned, Serialize};
39
/// On-disk encoding used when writing or reading a model checkpoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckpointFormat {
    /// MessagePack encoding. `save_model`/`load_record` currently handle this
    /// exactly like [`CheckpointFormat::NamedMessagePack`].
    MessagePack,
    /// MessagePack with named fields, handled by burn's `NamedMpkFileRecorder`.
    NamedMessagePack,
}
48
/// Numeric precision for tensors stored in a checkpoint.
///
/// NOTE(review): in the visible code, `save_model` and `load_record` always use
/// `FullPrecisionSettings` and do not consult this enum — confirm whether a
/// caller elsewhere maps it onto a recorder.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckpointPrecision {
    /// Full precision.
    Full,
    /// Reduced precision (presumably f16 — confirm against the recorder used).
    Half,
}
57
58pub fn save_model<B, M>(model: &M, path: impl AsRef<Path>, format: CheckpointFormat) -> Result<()>
70where
71 B: Backend,
72 M: Module<B>,
73 M::Record: Serialize,
74{
75 let path = path.as_ref();
76 let record = model.clone().into_record();
77
78 match format {
79 CheckpointFormat::MessagePack | CheckpointFormat::NamedMessagePack => {
80 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
81 recorder
82 .record(record, path.to_path_buf())
83 .map_err(|e| CheckpointError::Save(e.to_string()))?;
84 }
85 }
86
87 Ok(())
88}
89
90pub fn load_record<B, M>(
102 path: impl AsRef<Path>,
103 format: CheckpointFormat,
104 device: &B::Device,
105) -> Result<M::Record>
106where
107 B: Backend,
108 M: Module<B>,
109 M::Record: DeserializeOwned,
110{
111 let path = path.as_ref();
112
113 let record = match format {
114 CheckpointFormat::MessagePack | CheckpointFormat::NamedMessagePack => {
115 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
116 recorder
117 .load(path.to_path_buf(), device)
118 .map_err(|e| CheckpointError::Load(e.to_string()))?
119 }
120 };
121
122 Ok(record)
123}
124
125#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
127pub struct CheckpointMetadata {
128 pub arch: String,
130 pub config_json: String,
132 pub epoch: Option<usize>,
134 pub val_loss: Option<f32>,
136 pub val_acc: Option<f32>,
138 pub extra: std::collections::HashMap<String, String>,
140}
141
142impl CheckpointMetadata {
143 pub fn new(arch: impl Into<String>) -> Self {
145 Self {
146 arch: arch.into(),
147 config_json: String::new(),
148 epoch: None,
149 val_loss: None,
150 val_acc: None,
151 extra: std::collections::HashMap::new(),
152 }
153 }
154
155 #[must_use]
157 pub fn with_config<C: Serialize>(mut self, config: &C) -> Self {
158 self.config_json = serde_json::to_string(config).unwrap_or_default();
159 self
160 }
161
162 #[must_use]
164 pub fn with_epoch(mut self, epoch: usize) -> Self {
165 self.epoch = Some(epoch);
166 self
167 }
168
169 #[must_use]
171 pub fn with_val_loss(mut self, loss: f32) -> Self {
172 self.val_loss = Some(loss);
173 self
174 }
175
176 #[must_use]
178 pub fn with_val_acc(mut self, acc: f32) -> Self {
179 self.val_acc = Some(acc);
180 self
181 }
182
183 #[must_use]
185 pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
186 self.extra.insert(key.into(), value.into());
187 self
188 }
189
190 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
192 let json = serde_json::to_string_pretty(self)
193 .map_err(|e| CheckpointError::Save(e.to_string()))?;
194 std::fs::write(path, json).map_err(|e| CheckpointError::Save(e.to_string()))?;
195 Ok(())
196 }
197
198 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
200 let json =
201 std::fs::read_to_string(path).map_err(|e| CheckpointError::Load(e.to_string()))?;
202 serde_json::from_str(&json).map_err(|e| CheckpointError::Load(e.to_string()))
203 }
204}
205
206pub type Result<T> = std::result::Result<T, CheckpointError>;
208
209#[derive(Debug, thiserror::Error)]
211pub enum CheckpointError {
212 #[error("Failed to save checkpoint: {0}")]
214 Save(String),
215
216 #[error("Failed to load checkpoint: {0}")]
218 Load(String),
219
220 #[error("Invalid checkpoint format: {0}")]
222 InvalidFormat(String),
223}
224
225pub trait ModelCheckpoint<B: Backend>: Module<B> {
227 fn save_checkpoint(&self, path: impl AsRef<Path>) -> Result<()>
229 where
230 Self::Record: Serialize,
231 {
232 save_model::<B, Self>(self, path, CheckpointFormat::NamedMessagePack)
233 }
234
235 fn load_checkpoint(&self, path: impl AsRef<Path>, device: &B::Device) -> Result<Self>
237 where
238 Self: Sized,
239 Self::Record: DeserializeOwned,
240 {
241 let record = load_record::<B, Self>(path, CheckpointFormat::NamedMessagePack, device)?;
242 Ok(self.clone().load_record(record))
243 }
244}
245
246impl<B: Backend, M: Module<B>> ModelCheckpoint<B> for M {}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_metadata() {
        let meta = CheckpointMetadata::new("InceptionTimePlus")
            .with_epoch(10)
            .with_val_loss(0.25)
            .with_val_acc(0.92)
            .with_extra("dataset", "ECG200");

        assert_eq!(meta.arch, "InceptionTimePlus");
        assert_eq!(meta.epoch, Some(10));
        assert_eq!(meta.val_loss, Some(0.25));
        assert_eq!(meta.val_acc, Some(0.92));
        assert_eq!(meta.extra.get("dataset"), Some(&"ECG200".to_string()));
    }

    #[test]
    fn test_checkpoint_metadata_with_config() {
        let meta = CheckpointMetadata::new("InceptionTimePlus").with_config(&vec![1, 2, 3]);
        assert_eq!(meta.config_json, "[1,2,3]");
    }

    #[test]
    fn test_checkpoint_metadata_save_load_roundtrip() {
        // Unique per process so concurrent test runs do not collide.
        let path = std::env::temp_dir().join(format!(
            "tsai_checkpoint_meta_roundtrip_{}.json",
            std::process::id()
        ));
        let meta = CheckpointMetadata::new("InceptionTimePlus")
            .with_epoch(3)
            .with_val_loss(0.5)
            .with_extra("dataset", "ECG200");

        meta.save(&path).expect("metadata should save");
        let loaded = CheckpointMetadata::load(&path).expect("metadata should load");
        let _ = std::fs::remove_file(&path);

        assert_eq!(loaded.arch, meta.arch);
        assert_eq!(loaded.config_json, meta.config_json);
        assert_eq!(loaded.epoch, Some(3));
        assert_eq!(loaded.val_loss, Some(0.5));
        assert_eq!(loaded.val_acc, None);
        assert_eq!(loaded.extra, meta.extra);
    }

    #[test]
    fn test_checkpoint_metadata_load_missing_file() {
        let path = std::env::temp_dir().join(format!(
            "tsai_checkpoint_meta_missing_{}.json",
            std::process::id()
        ));
        assert!(matches!(
            CheckpointMetadata::load(&path),
            Err(CheckpointError::Load(_))
        ));
    }

    #[test]
    fn test_checkpoint_format() {
        assert_eq!(CheckpointFormat::MessagePack, CheckpointFormat::MessagePack);
        assert_ne!(
            CheckpointFormat::MessagePack,
            CheckpointFormat::NamedMessagePack
        );
    }

    #[test]
    fn test_checkpoint_precision() {
        assert_eq!(CheckpointPrecision::Full, CheckpointPrecision::Full);
        assert_ne!(CheckpointPrecision::Full, CheckpointPrecision::Half);
    }
}