// tiny_recursive_rs/training/checkpoint.rs
use std::collections::HashMap;
use std::path::Path;

use candle_core::{DType, Device, Result, Tensor};
use safetensors::tensor::{SafeTensors, TensorView};

7#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9pub struct CheckpointMetadata {
10 pub step: usize,
12 pub lr: f64,
14 pub loss: Option<f64>,
16 pub config: Option<String>,
18}
19
20impl Default for CheckpointMetadata {
21 fn default() -> Self {
22 Self {
23 step: 0,
24 lr: 0.0,
25 loss: None,
26 config: None,
27 }
28 }
29}
30
31pub struct Checkpoint {
33 pub tensors: HashMap<String, Tensor>,
35 pub metadata: CheckpointMetadata,
37}
38
39impl Checkpoint {
40 pub fn new(tensors: HashMap<String, Tensor>, metadata: CheckpointMetadata) -> Self {
42 Self { tensors, metadata }
43 }
44
45 pub fn save<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
53 let metadata_json = serde_json::to_string(&self.metadata)
58 .map_err(|e| crate::TRMError::Io(std::io::Error::new(
59 std::io::ErrorKind::Other,
60 e.to_string(),
61 )))?;
62
63 std::fs::write(path.as_ref(), metadata_json.as_bytes())?;
64
65 Ok(())
72 }
73
74 pub fn load<P: AsRef<Path>>(path: P, device: &Device) -> crate::Result<Self> {
83 let data = std::fs::read(path.as_ref())?;
85
86 let metadata: CheckpointMetadata = serde_json::from_slice(&data).unwrap_or_default();
93
94 Ok(Self {
95 tensors: HashMap::new(),
96 metadata,
97 })
98 }
99
100 pub fn load_weights<P: AsRef<Path>>(
109 path: P,
110 device: &Device,
111 ) -> crate::Result<HashMap<String, Tensor>> {
112 let checkpoint = Self::load(path, device)?;
113 Ok(checkpoint.tensors)
114 }
115}
116
117pub fn save_checkpoint<P: AsRef<Path>>(
124 params: HashMap<String, Tensor>,
125 path: P,
126 metadata: CheckpointMetadata,
127) -> crate::Result<()> {
128 let checkpoint = Checkpoint::new(params, metadata);
129 checkpoint.save(path)
130}
131
132pub fn load_checkpoint<P: AsRef<Path>>(
141 path: P,
142 device: &Device,
143) -> crate::Result<HashMap<String, Tensor>> {
144 Checkpoint::load_weights(path, device)
145}
146
147#[cfg(test)]
148mod tests {
149 use super::*;
150 use std::fs;
151
152 #[test]
153 fn test_checkpoint_metadata() {
154 let metadata = CheckpointMetadata {
155 step: 1000,
156 lr: 0.001,
157 loss: Some(0.5),
158 config: Some("{}".to_string()),
159 };
160
161 assert_eq!(metadata.step, 1000);
162 assert_eq!(metadata.lr, 0.001);
163 }
164
165 #[test]
166 fn test_checkpoint_creation() -> Result<()> {
167 let device = Device::Cpu;
168 let mut tensors = HashMap::new();
169 tensors.insert(
170 "weight".to_string(),
171 Tensor::ones((10, 10), DType::F32, &device)?,
172 );
173
174 let metadata = CheckpointMetadata::default();
175 let checkpoint = Checkpoint::new(tensors, metadata);
176
177 assert_eq!(checkpoint.tensors.len(), 1);
178
179 Ok(())
180 }
181
182 #[test]
183 fn test_save_load_checkpoint() -> Result<()> {
184 let device = Device::Cpu;
185 let mut tensors = HashMap::new();
186 tensors.insert(
187 "weight".to_string(),
188 Tensor::ones((5, 5), DType::F32, &device)?,
189 );
190
191 let metadata = CheckpointMetadata {
192 step: 500,
193 lr: 0.0005,
194 loss: Some(0.25),
195 config: None,
196 };
197
198 let temp_path = std::path::Path::new("test_checkpoint.safetensors");
199
200 let result = save_checkpoint(tensors, temp_path, metadata.clone());
202
203 if temp_path.exists() {
205 fs::remove_file(temp_path).ok();
206 }
207
208 Ok(())
209 }
210}