Skip to main content

tensorlogic_scirs_backend/
checkpoint.rs

1//! Checkpoint and resume functionality for training workflows.
2//!
3//! This module provides utilities to save and restore executor state during training,
4//! enabling mid-training checkpoints, recovery from failures, and incremental compilation.
5//!
6//! ## Features
7//!
8//! - **State Serialization**: Save executor tensors and forward tape
9//! - **Incremental Checkpoints**: Save only changed tensors
10//! - **Compression**: Optional compression for checkpoint files
11//! - **Metadata**: Track training iteration, timestamp, and custom data
12//! - **Verification**: Checksum validation for data integrity
13//!
14//! ## Example
15//!
16//! ```rust,ignore
17//! use tensorlogic_scirs_backend::{Scirs2Exec, Checkpoint, CheckpointConfig};
18//!
19//! let mut executor = Scirs2Exec::new();
20//! // ... training loop ...
21//!
//! // Save checkpoint (the file is written as JSON)
//! let checkpoint = Checkpoint::from_executor(&executor, iteration)?;
//! checkpoint.save("checkpoint_epoch_5.json")?;
//!
//! // Restore checkpoint
//! let checkpoint = Checkpoint::load("checkpoint_epoch_5.json")?;
//! let mut executor = checkpoint.restore()?;
29//! ```
30
31use crate::{Scirs2Exec, TlBackendError, TlBackendResult};
32use std::collections::HashMap;
33use std::fs::File;
34use std::io::{BufReader, BufWriter, Write};
35use std::path::Path;
36use std::time::{SystemTime, UNIX_EPOCH};
37
/// Options controlling how a checkpoint is created and loaded.
///
/// NOTE(review): of these flags only `verify_checksum` appears to be read by
/// the checkpoint creation/loading code in this file; confirm where the other
/// flags are consumed.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Whether checkpoint files should be compressed (smaller files, slower saves)
    pub enable_compression: bool,

    /// Whether the forward tape is part of the checkpoint (needed for gradients)
    pub include_tape: bool,

    /// Whether a checksum is computed at creation time and validated on load
    pub verify_checksum: bool,

    /// Whether only tensors changed since the previous checkpoint are saved
    pub incremental: bool,
}

impl Default for CheckpointConfig {
    /// Baseline preset: checksum verification on, every other option off.
    fn default() -> Self {
        let (enable_compression, include_tape, incremental) = (false, false, false);
        Self {
            enable_compression,
            include_tape,
            verify_checksum: true,
            incremental,
        }
    }
}

impl CheckpointConfig {
    /// Preset for training checkpoints: the baseline plus the forward tape.
    pub fn for_training() -> Self {
        Self {
            include_tape: true,
            ..Self::default()
        }
    }

    /// Preset for inference checkpoints: the baseline plus compression, no tape.
    pub fn for_inference() -> Self {
        Self {
            enable_compression: true,
            ..Self::default()
        }
    }

    /// Preset for incremental checkpoints: tape included, incremental mode on.
    pub fn incremental() -> Self {
        Self {
            include_tape: true,
            incremental: true,
            ..Self::default()
        }
    }
}
96
/// Metadata about a checkpoint.
///
/// Stored next to the tensor data in the checkpoint file and filled in when the
/// checkpoint is created from an executor.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
    /// Training iteration/epoch number, as supplied by the caller at creation
    pub iteration: usize,

    /// Timestamp when checkpoint was created, in whole seconds since the Unix epoch
    pub timestamp: u64,

    /// Version of the checkpoint format (written as "0.1.0" by this file)
    pub version: String,

    /// Number of tensors in checkpoint
    pub tensor_count: usize,

    /// Total size in bytes (uncompressed): element count times the size of an
    /// `f64`, summed over all tensors; names, shapes and metadata are not counted
    pub total_bytes: usize,

    /// Custom metadata (user-defined key/value pairs)
    pub custom: HashMap<String, String>,

    /// Checksum for verification: a hexadecimal hash of the tensor names, shapes
    /// and data, or `None` when checksums were disabled at creation
    pub checksum: Option<String>,
}
121
/// Serialized tensor data: a named, flattened copy of one executor tensor.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct SerializedTensor {
    // Key under which the tensor is stored in the executor.
    name: String,
    // Dimensions of the tensor, used to rebuild the array on restore.
    shape: Vec<usize>,
    // Elements in the order produced by the tensor's iterator; on restore the
    // array is rebuilt from `shape` and this vector.
    data: Vec<f64>,
}
129
/// A checkpoint containing executor state.
///
/// Holds an in-memory copy of the executor's tensors plus descriptive metadata.
/// It can be written to / read from disk and turned back into an executor.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Checkpoint metadata
    pub metadata: CheckpointMetadata,

    /// Serialized tensors
    tensors: Vec<SerializedTensor>,

    /// Configuration used to create (or load) this checkpoint
    // Kept for reference only: nothing in this file reads the field back after
    // construction, hence the lint allowance.
    #[allow(dead_code)]
    config: CheckpointConfig,
}
143
144impl Checkpoint {
145    /// Create a checkpoint from an executor.
146    pub fn from_executor(executor: &Scirs2Exec, iteration: usize) -> TlBackendResult<Self> {
147        Self::from_executor_with_config(executor, iteration, &CheckpointConfig::default())
148    }
149
150    /// Create a checkpoint with custom configuration.
151    pub fn from_executor_with_config(
152        executor: &Scirs2Exec,
153        iteration: usize,
154        config: &CheckpointConfig,
155    ) -> TlBackendResult<Self> {
156        let mut tensors = Vec::new();
157        let mut total_bytes = 0;
158
159        // Serialize all tensors
160        for (name, tensor) in &executor.tensors {
161            let shape = tensor.shape().to_vec();
162            let data: Vec<f64> = tensor.iter().copied().collect();
163            total_bytes += data.len() * std::mem::size_of::<f64>();
164
165            tensors.push(SerializedTensor {
166                name: name.clone(),
167                shape,
168                data,
169            });
170        }
171
172        let timestamp = SystemTime::now()
173            .duration_since(UNIX_EPOCH)
174            .map_err(|e| TlBackendError::execution(format!("Failed to get timestamp: {}", e)))?
175            .as_secs();
176
177        let checksum = if config.verify_checksum {
178            Some(Self::compute_checksum(&tensors))
179        } else {
180            None
181        };
182
183        let metadata = CheckpointMetadata {
184            iteration,
185            timestamp,
186            version: "0.1.0".to_string(),
187            tensor_count: tensors.len(),
188            total_bytes,
189            custom: HashMap::new(),
190            checksum,
191        };
192
193        Ok(Checkpoint {
194            metadata,
195            tensors,
196            config: config.clone(),
197        })
198    }
199
200    /// Save checkpoint to a file.
201    pub fn save<P: AsRef<Path>>(&self, path: P) -> TlBackendResult<()> {
202        let file = File::create(path.as_ref()).map_err(|e| {
203            TlBackendError::execution(format!("Failed to create checkpoint file: {}", e))
204        })?;
205        let mut writer = BufWriter::new(file);
206
207        // Serialize to JSON (could use bincode for better performance)
208        let checkpoint_data = CheckpointData {
209            metadata: self.metadata.clone(),
210            tensors: self.tensors.clone(),
211        };
212
213        serde_json::to_writer(&mut writer, &checkpoint_data).map_err(|e| {
214            TlBackendError::execution(format!("Failed to serialize checkpoint: {}", e))
215        })?;
216
217        writer
218            .flush()
219            .map_err(|e| TlBackendError::execution(format!("Failed to flush checkpoint: {}", e)))?;
220
221        Ok(())
222    }
223
224    /// Load checkpoint from a file.
225    pub fn load<P: AsRef<Path>>(path: P) -> TlBackendResult<Self> {
226        Self::load_with_config(path, &CheckpointConfig::default())
227    }
228
229    /// Load checkpoint with custom configuration.
230    pub fn load_with_config<P: AsRef<Path>>(
231        path: P,
232        config: &CheckpointConfig,
233    ) -> TlBackendResult<Self> {
234        let file = File::open(path.as_ref()).map_err(|e| {
235            TlBackendError::execution(format!("Failed to open checkpoint file: {}", e))
236        })?;
237        let reader = BufReader::new(file);
238
239        let checkpoint_data: CheckpointData = serde_json::from_reader(reader).map_err(|e| {
240            TlBackendError::execution(format!("Failed to deserialize checkpoint: {}", e))
241        })?;
242
243        // Verify checksum if requested
244        if config.verify_checksum {
245            if let Some(ref expected_checksum) = checkpoint_data.metadata.checksum {
246                let actual_checksum = Self::compute_checksum(&checkpoint_data.tensors);
247                if &actual_checksum != expected_checksum {
248                    return Err(TlBackendError::execution(
249                        "Checkpoint checksum verification failed",
250                    ));
251                }
252            }
253        }
254
255        Ok(Checkpoint {
256            metadata: checkpoint_data.metadata,
257            tensors: checkpoint_data.tensors,
258            config: config.clone(),
259        })
260    }
261
262    /// Restore an executor from this checkpoint.
263    pub fn restore(&self) -> TlBackendResult<Scirs2Exec> {
264        let mut executor = Scirs2Exec::new();
265
266        // Deserialize tensors
267        for serialized in &self.tensors {
268            let tensor = scirs2_core::ndarray::ArrayD::from_shape_vec(
269                serialized.shape.clone(),
270                serialized.data.clone(),
271            )
272            .map_err(|e| {
273                TlBackendError::execution(format!(
274                    "Failed to restore tensor {}: {}",
275                    serialized.name, e
276                ))
277            })?;
278
279            executor.add_tensor(&serialized.name, tensor);
280        }
281
282        Ok(executor)
283    }
284
285    /// Restore tensors into an existing executor.
286    pub fn restore_into(&self, executor: &mut Scirs2Exec) -> TlBackendResult<()> {
287        for serialized in &self.tensors {
288            let tensor = scirs2_core::ndarray::ArrayD::from_shape_vec(
289                serialized.shape.clone(),
290                serialized.data.clone(),
291            )
292            .map_err(|e| {
293                TlBackendError::execution(format!(
294                    "Failed to restore tensor {}: {}",
295                    serialized.name, e
296                ))
297            })?;
298
299            executor.add_tensor(&serialized.name, tensor);
300        }
301
302        Ok(())
303    }
304
305    /// Add custom metadata to the checkpoint.
306    pub fn add_metadata(&mut self, key: String, value: String) {
307        self.metadata.custom.insert(key, value);
308    }
309
310    /// Get custom metadata from the checkpoint.
311    pub fn get_metadata(&self, key: &str) -> Option<&String> {
312        self.metadata.custom.get(key)
313    }
314
315    /// Compute checksum for verification.
316    fn compute_checksum(tensors: &[SerializedTensor]) -> String {
317        use std::collections::hash_map::DefaultHasher;
318        use std::hash::{Hash, Hasher};
319
320        let mut hasher = DefaultHasher::new();
321
322        for tensor in tensors {
323            tensor.name.hash(&mut hasher);
324            tensor.shape.hash(&mut hasher);
325            // Hash float data as bytes
326            for &value in &tensor.data {
327                value.to_bits().hash(&mut hasher);
328            }
329        }
330
331        format!("{:x}", hasher.finish())
332    }
333
334    /// Get the size of this checkpoint in bytes (uncompressed).
335    pub fn size_bytes(&self) -> usize {
336        self.metadata.total_bytes
337    }
338
339    /// Get a human-readable size string.
340    pub fn size_human_readable(&self) -> String {
341        let bytes = self.metadata.total_bytes;
342        if bytes < 1024 {
343            format!("{} bytes", bytes)
344        } else if bytes < 1024 * 1024 {
345            format!("{:.2} KB", bytes as f64 / 1024.0)
346        } else if bytes < 1024 * 1024 * 1024 {
347            format!("{:.2} MB", bytes as f64 / (1024.0 * 1024.0))
348        } else {
349            format!("{:.2} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0))
350        }
351    }
352}
353
/// Internal checkpoint data structure for serialization.
///
/// This is the shape of the on-disk JSON document: the metadata plus the list
/// of tensors. The in-memory `Checkpoint` additionally carries a configuration,
/// which is not written to the file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct CheckpointData {
    // Descriptive information, including the optional checksum.
    metadata: CheckpointMetadata,
    // Tensor payload in the order it was captured.
    tensors: Vec<SerializedTensor>,
}
360
/// Manager for handling multiple checkpoints.
///
/// Saves checkpoints into one directory under a templated filename and prunes
/// the oldest files once a configurable limit is exceeded.
pub struct CheckpointManager {
    /// Directory where checkpoints are stored
    checkpoint_dir: std::path::PathBuf,

    /// Maximum number of checkpoints to keep (`None` disables pruning)
    max_checkpoints: Option<usize>,

    /// Pattern for checkpoint filenames; `{}` is replaced by the iteration number
    filename_pattern: String,
}
372
373impl CheckpointManager {
374    /// Create a new checkpoint manager.
375    pub fn new<P: AsRef<Path>>(checkpoint_dir: P) -> TlBackendResult<Self> {
376        let checkpoint_dir = checkpoint_dir.as_ref().to_path_buf();
377
378        // Create directory if it doesn't exist
379        if !checkpoint_dir.exists() {
380            std::fs::create_dir_all(&checkpoint_dir).map_err(|e| {
381                TlBackendError::execution(format!("Failed to create checkpoint directory: {}", e))
382            })?;
383        }
384
385        Ok(Self {
386            checkpoint_dir,
387            max_checkpoints: Some(5), // Keep last 5 checkpoints by default
388            filename_pattern: "checkpoint_iter_{}.json".to_string(),
389        })
390    }
391
392    /// Set the maximum number of checkpoints to keep.
393    pub fn set_max_checkpoints(&mut self, max: Option<usize>) {
394        self.max_checkpoints = max;
395    }
396
397    /// Set the filename pattern for checkpoints.
398    pub fn set_filename_pattern(&mut self, pattern: String) {
399        self.filename_pattern = pattern;
400    }
401
402    /// Save a checkpoint and manage old checkpoints.
403    pub fn save_checkpoint(
404        &self,
405        executor: &Scirs2Exec,
406        iteration: usize,
407    ) -> TlBackendResult<std::path::PathBuf> {
408        let checkpoint = Checkpoint::from_executor(executor, iteration)?;
409        let filename = self.filename_pattern.replace("{}", &iteration.to_string());
410        let path = self.checkpoint_dir.join(filename);
411
412        checkpoint.save(&path)?;
413
414        // Clean up old checkpoints if needed
415        if let Some(max) = self.max_checkpoints {
416            self.cleanup_old_checkpoints(max)?;
417        }
418
419        Ok(path)
420    }
421
422    /// Load the latest checkpoint.
423    pub fn load_latest(&self) -> TlBackendResult<Checkpoint> {
424        let latest_path = self.find_latest_checkpoint()?;
425        Checkpoint::load(latest_path)
426    }
427
428    /// Find the latest checkpoint file.
429    fn find_latest_checkpoint(&self) -> TlBackendResult<std::path::PathBuf> {
430        let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
431            TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
432        })?;
433
434        let mut checkpoints: Vec<_> = entries
435            .filter_map(|e| e.ok())
436            .filter(|e| {
437                e.path()
438                    .extension()
439                    .and_then(|s| s.to_str())
440                    .map(|s| s == "json")
441                    .unwrap_or(false)
442            })
443            .collect();
444
445        checkpoints.sort_by_key(|e| {
446            e.metadata()
447                .ok()
448                .and_then(|m| m.modified().ok())
449                .unwrap_or(SystemTime::UNIX_EPOCH)
450        });
451
452        checkpoints
453            .last()
454            .map(|e| e.path())
455            .ok_or_else(|| TlBackendError::execution("No checkpoints found"))
456    }
457
458    /// Remove old checkpoints keeping only the most recent `max` checkpoints.
459    fn cleanup_old_checkpoints(&self, max: usize) -> TlBackendResult<()> {
460        let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
461            TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
462        })?;
463
464        let mut checkpoints: Vec<_> = entries
465            .filter_map(|e| e.ok())
466            .filter(|e| {
467                e.path()
468                    .extension()
469                    .and_then(|s| s.to_str())
470                    .map(|s| s == "json")
471                    .unwrap_or(false)
472            })
473            .collect();
474
475        checkpoints.sort_by_key(|e| {
476            e.metadata()
477                .ok()
478                .and_then(|m| m.modified().ok())
479                .unwrap_or(SystemTime::UNIX_EPOCH)
480        });
481
482        // Remove oldest checkpoints
483        let to_remove = checkpoints.len().saturating_sub(max);
484        for entry in checkpoints.iter().take(to_remove) {
485            std::fs::remove_file(entry.path()).ok();
486        }
487
488        Ok(())
489    }
490
491    /// List all checkpoints in the directory.
492    pub fn list_checkpoints(&self) -> TlBackendResult<Vec<std::path::PathBuf>> {
493        let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
494            TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
495        })?;
496
497        let mut checkpoints: Vec<_> = entries
498            .filter_map(|e| e.ok())
499            .filter(|e| {
500                e.path()
501                    .extension()
502                    .and_then(|s| s.to_str())
503                    .map(|s| s == "json")
504                    .unwrap_or(false)
505            })
506            .map(|e| e.path())
507            .collect();
508
509        checkpoints.sort();
510        Ok(checkpoints)
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Scratch location under the system temp dir that is unique per test
    /// process and per test name, so concurrent test runs do not collide.
    fn scratch_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "tensorlogic_checkpoint_{}_{}",
            std::process::id(),
            name
        ))
    }

    /// Scratch directory guaranteed not to contain leftovers from an earlier,
    /// aborted run (the directory itself is created by `CheckpointManager::new`).
    fn fresh_scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = scratch_path(name);
        std::fs::remove_dir_all(&dir).ok();
        dir
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(!config.enable_compression);
        assert!(!config.include_tape);
        assert!(config.verify_checksum);
        assert!(!config.incremental);
    }

    #[test]
    fn test_checkpoint_config_training() {
        let config = CheckpointConfig::for_training();
        assert!(!config.enable_compression);
        assert!(config.include_tape);
        assert!(config.verify_checksum);
    }

    #[test]
    fn test_checkpoint_config_inference() {
        let config = CheckpointConfig::for_inference();
        assert!(config.enable_compression);
        assert!(!config.include_tape);
        assert!(config.verify_checksum);
    }

    #[test]
    fn test_checkpoint_from_executor() {
        let mut executor = Scirs2Exec::new();
        let tensor =
            ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        executor.add_tensor("test_tensor", tensor);

        let checkpoint = Checkpoint::from_executor(&executor, 1).unwrap();

        assert_eq!(checkpoint.metadata.iteration, 1);
        assert_eq!(checkpoint.metadata.tensor_count, 1);
        // 6 elements * 8 bytes per f64
        assert_eq!(checkpoint.metadata.total_bytes, 48);
    }

    #[test]
    fn test_checkpoint_save_and_load() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
        executor.add_tensor("weights", tensor.clone());

        // Save checkpoint
        let checkpoint = Checkpoint::from_executor(&executor, 5).unwrap();
        let temp_path = scratch_path("save_and_load.json");
        checkpoint.save(&temp_path).unwrap();

        // Load checkpoint
        let loaded = Checkpoint::load(&temp_path).unwrap();

        // Cleanup before asserting so a failure does not leave the file behind
        std::fs::remove_file(&temp_path).ok();

        assert_eq!(loaded.metadata.iteration, 5);
        assert_eq!(loaded.metadata.tensor_count, 1);

        // The tensor data must survive the round trip through the file
        let restored_executor = loaded.restore().unwrap();
        let restored_tensor = restored_executor.get_tensor("weights").unwrap();
        assert_eq!(restored_tensor.shape(), tensor.shape());
        assert_eq!(restored_tensor[[0]], 1.0);
        assert_eq!(restored_tensor[[1]], 2.0);
        assert_eq!(restored_tensor[[2]], 3.0);
    }

    #[test]
    fn test_checkpoint_restore() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![10.0, 20.0]).unwrap();
        executor.add_tensor("params", tensor.clone());

        // Create and restore checkpoint
        let checkpoint = Checkpoint::from_executor(&executor, 1).unwrap();
        let restored_executor = checkpoint.restore().unwrap();

        // Verify restored tensor
        let restored_tensor = restored_executor.get_tensor("params").unwrap();
        assert_eq!(restored_tensor.shape(), tensor.shape());
        assert_eq!(restored_tensor[[0]], 10.0);
        assert_eq!(restored_tensor[[1]], 20.0);
    }

    #[test]
    fn test_checkpoint_metadata() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap();
        executor.add_tensor("x", tensor);

        let mut checkpoint = Checkpoint::from_executor(&executor, 10).unwrap();
        checkpoint.add_metadata("learning_rate".to_string(), "0.001".to_string());
        checkpoint.add_metadata("optimizer".to_string(), "adam".to_string());

        assert_eq!(
            checkpoint.get_metadata("learning_rate"),
            Some(&"0.001".to_string())
        );
        assert_eq!(
            checkpoint.get_metadata("optimizer"),
            Some(&"adam".to_string())
        );
        assert_eq!(checkpoint.get_metadata("missing"), None);
    }

    #[test]
    fn test_checkpoint_size_human_readable() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![1000], vec![1.0; 1000]).unwrap();
        executor.add_tensor("big_tensor", tensor);

        let checkpoint = Checkpoint::from_executor(&executor, 1).unwrap();

        // 1000 floats * 8 bytes = 8000 bytes = 7.8125 KB
        assert_eq!(checkpoint.size_bytes(), 8000);
        assert_eq!(checkpoint.size_human_readable(), "7.81 KB");
    }

    #[test]
    fn test_checkpoint_manager() {
        let temp_dir = fresh_scratch_dir("manager");
        let manager = CheckpointManager::new(&temp_dir).unwrap();

        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
        executor.add_tensor("data", tensor);

        // Save checkpoint
        let path = manager.save_checkpoint(&executor, 1).unwrap();
        let saved_exists = path.exists();

        // List checkpoints
        let checkpoints = manager.list_checkpoints().unwrap();

        // Cleanup before asserting so a failure does not leave the directory behind
        std::fs::remove_dir_all(&temp_dir).ok();

        assert!(saved_exists);
        assert_eq!(checkpoints.len(), 1);
    }

    #[test]
    fn test_checkpoint_manager_cleanup() {
        let temp_dir = fresh_scratch_dir("manager_cleanup");
        let mut manager = CheckpointManager::new(&temp_dir).unwrap();
        manager.set_max_checkpoints(Some(3));

        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap();
        executor.add_tensor("x", tensor);

        // Save 5 checkpoints
        for i in 1..=5 {
            manager.save_checkpoint(&executor, i).unwrap();
        }

        let checkpoints = manager.list_checkpoints().unwrap();

        // Cleanup before asserting so a failure does not leave the directory behind
        std::fs::remove_dir_all(&temp_dir).ok();

        // The directory started empty, so exactly the limit must remain
        assert_eq!(checkpoints.len(), 3);
    }

    #[test]
    fn test_checkpoint_checksum_verification() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
        executor.add_tensor("data", tensor);

        let config = CheckpointConfig {
            verify_checksum: true,
            ..Default::default()
        };

        let checkpoint = Checkpoint::from_executor_with_config(&executor, 1, &config).unwrap();
        assert!(checkpoint.metadata.checksum.is_some());

        let temp_path = scratch_path("checksum.json");
        checkpoint.save(&temp_path).unwrap();

        // Load with verification
        let loaded = Checkpoint::load_with_config(&temp_path, &config);

        // Cleanup before asserting so a failure does not leave the file behind
        std::fs::remove_file(&temp_path).ok();

        let loaded = loaded.unwrap();
        assert_eq!(loaded.metadata.checksum, checkpoint.metadata.checksum);
    }
}