// tiny_recursive_rs/training/checkpoint.rs

//! Model checkpointing with safetensors
use std::collections::HashMap;
use std::path::Path;

use candle_core::{DType, Device, Result, Tensor};
use safetensors::tensor::{SafeTensors, TensorView};

7/// Checkpoint metadata
8#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9pub struct CheckpointMetadata {
10    /// Training step
11    pub step: usize,
12    /// Learning rate at checkpoint
13    pub lr: f64,
14    /// Loss at checkpoint
15    pub loss: Option<f64>,
16    /// Model configuration (as JSON string)
17    pub config: Option<String>,
18}
19
20impl Default for CheckpointMetadata {
21    fn default() -> Self {
22        Self {
23            step: 0,
24            lr: 0.0,
25            loss: None,
26            config: None,
27        }
28    }
29}
30
31/// Model checkpoint
32pub struct Checkpoint {
33    /// Model parameters
34    pub tensors: HashMap<String, Tensor>,
35    /// Metadata
36    pub metadata: CheckpointMetadata,
37}
38
39impl Checkpoint {
40    /// Create new checkpoint
41    pub fn new(tensors: HashMap<String, Tensor>, metadata: CheckpointMetadata) -> Self {
42        Self { tensors, metadata }
43    }
44
45    /// Save checkpoint to file
46    ///
47    /// # Arguments
48    /// * `path` - Path to save checkpoint
49    ///
50    /// # Returns
51    /// Result indicating success or error
52    pub fn save<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
53        // Simplified checkpoint saving
54        // In a full implementation, this would use safetensors::serialize
55        // For now, just save metadata
56
57        let metadata_json = serde_json::to_string(&self.metadata)
58            .map_err(|e| crate::TRMError::Io(std::io::Error::new(
59                std::io::ErrorKind::Other,
60                e.to_string(),
61            )))?;
62
63        std::fs::write(path.as_ref(), metadata_json.as_bytes())?;
64
65        // TODO: Implement full safetensors serialization
66        // This would involve:
67        // 1. Collecting all tensor data
68        // 2. Creating proper TensorView objects
69        // 3. Using safetensors::serialize to write binary format
70
71        Ok(())
72    }
73
74    /// Load checkpoint from file
75    ///
76    /// # Arguments
77    /// * `path` - Path to checkpoint file
78    /// * `device` - Device to load tensors on
79    ///
80    /// # Returns
81    /// Loaded checkpoint
82    pub fn load<P: AsRef<Path>>(path: P, device: &Device) -> crate::Result<Self> {
83        // Read file
84        let data = std::fs::read(path.as_ref())?;
85
86        // This is a simplified placeholder
87        // Proper implementation would:
88        // 1. Parse safetensors format
89        // 2. Load tensors onto device
90        // 3. Extract metadata
91
92        let metadata: CheckpointMetadata = serde_json::from_slice(&data).unwrap_or_default();
93
94        Ok(Self {
95            tensors: HashMap::new(),
96            metadata,
97        })
98    }
99
100    /// Load only model weights from checkpoint
101    ///
102    /// # Arguments
103    /// * `path` - Path to checkpoint file
104    /// * `device` - Device to load tensors on
105    ///
106    /// # Returns
107    /// HashMap of parameter name to tensor
108    pub fn load_weights<P: AsRef<Path>>(
109        path: P,
110        device: &Device,
111    ) -> crate::Result<HashMap<String, Tensor>> {
112        let checkpoint = Self::load(path, device)?;
113        Ok(checkpoint.tensors)
114    }
115}
116
117/// Save model parameters to checkpoint
118///
119/// # Arguments
120/// * `params` - Model parameters as (name, tensor) pairs
121/// * `path` - Path to save checkpoint
122/// * `metadata` - Checkpoint metadata
123pub fn save_checkpoint<P: AsRef<Path>>(
124    params: HashMap<String, Tensor>,
125    path: P,
126    metadata: CheckpointMetadata,
127) -> crate::Result<()> {
128    let checkpoint = Checkpoint::new(params, metadata);
129    checkpoint.save(path)
130}
131
132/// Load model parameters from checkpoint
133///
134/// # Arguments
135/// * `path` - Path to checkpoint file
136/// * `device` - Device to load tensors on
137///
138/// # Returns
139/// HashMap of parameter name to tensor
140pub fn load_checkpoint<P: AsRef<Path>>(
141    path: P,
142    device: &Device,
143) -> crate::Result<HashMap<String, Tensor>> {
144    Checkpoint::load_weights(path, device)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn test_checkpoint_metadata() {
        let metadata = CheckpointMetadata {
            step: 1000,
            lr: 0.001,
            loss: Some(0.5),
            config: Some("{}".to_string()),
        };

        assert_eq!(metadata.step, 1000);
        assert_eq!(metadata.lr, 0.001);
        assert_eq!(metadata.loss, Some(0.5));
        assert_eq!(metadata.config.as_deref(), Some("{}"));
    }

    #[test]
    fn test_checkpoint_creation() -> Result<()> {
        let device = Device::Cpu;
        let mut tensors = HashMap::new();
        tensors.insert(
            "weight".to_string(),
            Tensor::ones((10, 10), DType::F32, &device)?,
        );

        let metadata = CheckpointMetadata::default();
        let checkpoint = Checkpoint::new(tensors, metadata);

        assert_eq!(checkpoint.tensors.len(), 1);

        Ok(())
    }

    /// Saving and reloading a checkpoint must succeed and round-trip the metadata.
    ///
    /// The file goes to the system temp dir (not the CWD), with the process id in
    /// its name. It is removed before any assertion runs, so a failure does not
    /// leave it behind.
    #[test]
    fn test_save_load_checkpoint() -> Result<()> {
        let device = Device::Cpu;
        let mut tensors = HashMap::new();
        tensors.insert(
            "weight".to_string(),
            Tensor::ones((5, 5), DType::F32, &device)?,
        );

        let metadata = CheckpointMetadata {
            step: 500,
            lr: 0.0005,
            loss: Some(0.25),
            config: None,
        };

        let temp_path = std::env::temp_dir().join(format!(
            "trm_test_checkpoint_{}.safetensors",
            std::process::id()
        ));

        let saved = save_checkpoint(tensors, &temp_path, metadata.clone());
        let loaded = Checkpoint::load(&temp_path, &device);

        // Clean up before asserting so a failing assertion doesn't leak the file.
        if temp_path.exists() {
            fs::remove_file(&temp_path).ok();
        }

        assert!(saved.is_ok(), "saving the checkpoint failed");
        let Ok(loaded) = loaded else {
            panic!("loading the checkpoint failed");
        };
        assert_eq!(loaded.metadata.step, metadata.step);
        assert_eq!(loaded.metadata.lr, metadata.lr);
        assert_eq!(loaded.metadata.loss, metadata.loss);
        assert_eq!(loaded.metadata.config, metadata.config);

        Ok(())
    }
}