1use crate::{Scirs2Exec, TlBackendError, TlBackendResult};
32use std::collections::HashMap;
33use std::fs::File;
34use std::io::{BufReader, BufWriter, Write};
35use std::path::Path;
36use std::time::{SystemTime, UNIX_EPOCH};
37
/// Options controlling how a checkpoint is created and loaded.
///
/// Within this file only `verify_checksum` is actually read by the checkpoint code; the
/// other flags are stored but not yet consulted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointConfig {
    /// Requests compressed checkpoint output.
    /// NOTE(review): not read by the save/load code in this file — confirm before relying on it.
    pub enable_compression: bool,

    /// Requests that the (presumably autodiff) tape be included in the checkpoint.
    /// NOTE(review): not read by the save/load code in this file.
    pub include_tape: bool,

    /// When creating a checkpoint, store a checksum of the tensors; when loading,
    /// recompute it and reject the checkpoint on mismatch.
    pub verify_checksum: bool,

    /// Requests incremental checkpoints.
    /// NOTE(review): not read by the save/load code in this file.
    pub incremental: bool,
}
53
54impl Default for CheckpointConfig {
55 fn default() -> Self {
56 Self {
57 enable_compression: false,
58 include_tape: false,
59 verify_checksum: true,
60 incremental: false,
61 }
62 }
63}
64
65impl CheckpointConfig {
66 pub fn for_training() -> Self {
68 Self {
69 enable_compression: false,
70 include_tape: true,
71 verify_checksum: true,
72 incremental: false,
73 }
74 }
75
76 pub fn for_inference() -> Self {
78 Self {
79 enable_compression: true,
80 include_tape: false,
81 verify_checksum: true,
82 incremental: false,
83 }
84 }
85
86 pub fn incremental() -> Self {
88 Self {
89 enable_compression: false,
90 include_tape: true,
91 verify_checksum: true,
92 incremental: true,
93 }
94 }
95}
96
97#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
99pub struct CheckpointMetadata {
100 pub iteration: usize,
102
103 pub timestamp: u64,
105
106 pub version: String,
108
109 pub tensor_count: usize,
111
112 pub total_bytes: usize,
114
115 pub custom: HashMap<String, String>,
117
118 pub checksum: Option<String>,
120}
121
122#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
124struct SerializedTensor {
125 name: String,
126 shape: Vec<usize>,
127 data: Vec<f64>,
128}
129
130#[derive(Debug, Clone)]
132pub struct Checkpoint {
133 pub metadata: CheckpointMetadata,
135
136 tensors: Vec<SerializedTensor>,
138
139 #[allow(dead_code)]
141 config: CheckpointConfig,
142}
143
144impl Checkpoint {
145 pub fn from_executor(executor: &Scirs2Exec, iteration: usize) -> TlBackendResult<Self> {
147 Self::from_executor_with_config(executor, iteration, &CheckpointConfig::default())
148 }
149
150 pub fn from_executor_with_config(
152 executor: &Scirs2Exec,
153 iteration: usize,
154 config: &CheckpointConfig,
155 ) -> TlBackendResult<Self> {
156 let mut tensors = Vec::new();
157 let mut total_bytes = 0;
158
159 for (name, tensor) in &executor.tensors {
161 let shape = tensor.shape().to_vec();
162 let data: Vec<f64> = tensor.iter().copied().collect();
163 total_bytes += data.len() * std::mem::size_of::<f64>();
164
165 tensors.push(SerializedTensor {
166 name: name.clone(),
167 shape,
168 data,
169 });
170 }
171
172 let timestamp = SystemTime::now()
173 .duration_since(UNIX_EPOCH)
174 .map_err(|e| TlBackendError::execution(format!("Failed to get timestamp: {}", e)))?
175 .as_secs();
176
177 let checksum = if config.verify_checksum {
178 Some(Self::compute_checksum(&tensors))
179 } else {
180 None
181 };
182
183 let metadata = CheckpointMetadata {
184 iteration,
185 timestamp,
186 version: "0.1.0".to_string(),
187 tensor_count: tensors.len(),
188 total_bytes,
189 custom: HashMap::new(),
190 checksum,
191 };
192
193 Ok(Checkpoint {
194 metadata,
195 tensors,
196 config: config.clone(),
197 })
198 }
199
200 pub fn save<P: AsRef<Path>>(&self, path: P) -> TlBackendResult<()> {
202 let file = File::create(path.as_ref()).map_err(|e| {
203 TlBackendError::execution(format!("Failed to create checkpoint file: {}", e))
204 })?;
205 let mut writer = BufWriter::new(file);
206
207 let checkpoint_data = CheckpointData {
209 metadata: self.metadata.clone(),
210 tensors: self.tensors.clone(),
211 };
212
213 serde_json::to_writer(&mut writer, &checkpoint_data).map_err(|e| {
214 TlBackendError::execution(format!("Failed to serialize checkpoint: {}", e))
215 })?;
216
217 writer
218 .flush()
219 .map_err(|e| TlBackendError::execution(format!("Failed to flush checkpoint: {}", e)))?;
220
221 Ok(())
222 }
223
224 pub fn load<P: AsRef<Path>>(path: P) -> TlBackendResult<Self> {
226 Self::load_with_config(path, &CheckpointConfig::default())
227 }
228
229 pub fn load_with_config<P: AsRef<Path>>(
231 path: P,
232 config: &CheckpointConfig,
233 ) -> TlBackendResult<Self> {
234 let file = File::open(path.as_ref()).map_err(|e| {
235 TlBackendError::execution(format!("Failed to open checkpoint file: {}", e))
236 })?;
237 let reader = BufReader::new(file);
238
239 let checkpoint_data: CheckpointData = serde_json::from_reader(reader).map_err(|e| {
240 TlBackendError::execution(format!("Failed to deserialize checkpoint: {}", e))
241 })?;
242
243 if config.verify_checksum {
245 if let Some(ref expected_checksum) = checkpoint_data.metadata.checksum {
246 let actual_checksum = Self::compute_checksum(&checkpoint_data.tensors);
247 if &actual_checksum != expected_checksum {
248 return Err(TlBackendError::execution(
249 "Checkpoint checksum verification failed",
250 ));
251 }
252 }
253 }
254
255 Ok(Checkpoint {
256 metadata: checkpoint_data.metadata,
257 tensors: checkpoint_data.tensors,
258 config: config.clone(),
259 })
260 }
261
262 pub fn restore(&self) -> TlBackendResult<Scirs2Exec> {
264 let mut executor = Scirs2Exec::new();
265
266 for serialized in &self.tensors {
268 let tensor = scirs2_core::ndarray::ArrayD::from_shape_vec(
269 serialized.shape.clone(),
270 serialized.data.clone(),
271 )
272 .map_err(|e| {
273 TlBackendError::execution(format!(
274 "Failed to restore tensor {}: {}",
275 serialized.name, e
276 ))
277 })?;
278
279 executor.add_tensor(&serialized.name, tensor);
280 }
281
282 Ok(executor)
283 }
284
285 pub fn restore_into(&self, executor: &mut Scirs2Exec) -> TlBackendResult<()> {
287 for serialized in &self.tensors {
288 let tensor = scirs2_core::ndarray::ArrayD::from_shape_vec(
289 serialized.shape.clone(),
290 serialized.data.clone(),
291 )
292 .map_err(|e| {
293 TlBackendError::execution(format!(
294 "Failed to restore tensor {}: {}",
295 serialized.name, e
296 ))
297 })?;
298
299 executor.add_tensor(&serialized.name, tensor);
300 }
301
302 Ok(())
303 }
304
305 pub fn add_metadata(&mut self, key: String, value: String) {
307 self.metadata.custom.insert(key, value);
308 }
309
310 pub fn get_metadata(&self, key: &str) -> Option<&String> {
312 self.metadata.custom.get(key)
313 }
314
315 fn compute_checksum(tensors: &[SerializedTensor]) -> String {
317 use std::collections::hash_map::DefaultHasher;
318 use std::hash::{Hash, Hasher};
319
320 let mut hasher = DefaultHasher::new();
321
322 for tensor in tensors {
323 tensor.name.hash(&mut hasher);
324 tensor.shape.hash(&mut hasher);
325 for &value in &tensor.data {
327 value.to_bits().hash(&mut hasher);
328 }
329 }
330
331 format!("{:x}", hasher.finish())
332 }
333
334 pub fn size_bytes(&self) -> usize {
336 self.metadata.total_bytes
337 }
338
339 pub fn size_human_readable(&self) -> String {
341 let bytes = self.metadata.total_bytes;
342 if bytes < 1024 {
343 format!("{} bytes", bytes)
344 } else if bytes < 1024 * 1024 {
345 format!("{:.2} KB", bytes as f64 / 1024.0)
346 } else if bytes < 1024 * 1024 * 1024 {
347 format!("{:.2} MB", bytes as f64 / (1024.0 * 1024.0))
348 } else {
349 format!("{:.2} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0))
350 }
351 }
352}
353
354#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
356struct CheckpointData {
357 metadata: CheckpointMetadata,
358 tensors: Vec<SerializedTensor>,
359}
360
/// Saves, lists, loads and prunes checkpoint files in a single directory.
#[derive(Debug)]
pub struct CheckpointManager {
    /// Directory holding the checkpoint files.
    checkpoint_dir: std::path::PathBuf,

    /// Maximum number of checkpoints kept after each save; `None` disables pruning.
    max_checkpoints: Option<usize>,

    /// File-name template; every `{}` is replaced by the iteration number.
    filename_pattern: String,
}
372
373impl CheckpointManager {
374 pub fn new<P: AsRef<Path>>(checkpoint_dir: P) -> TlBackendResult<Self> {
376 let checkpoint_dir = checkpoint_dir.as_ref().to_path_buf();
377
378 if !checkpoint_dir.exists() {
380 std::fs::create_dir_all(&checkpoint_dir).map_err(|e| {
381 TlBackendError::execution(format!("Failed to create checkpoint directory: {}", e))
382 })?;
383 }
384
385 Ok(Self {
386 checkpoint_dir,
387 max_checkpoints: Some(5), filename_pattern: "checkpoint_iter_{}.json".to_string(),
389 })
390 }
391
392 pub fn set_max_checkpoints(&mut self, max: Option<usize>) {
394 self.max_checkpoints = max;
395 }
396
397 pub fn set_filename_pattern(&mut self, pattern: String) {
399 self.filename_pattern = pattern;
400 }
401
402 pub fn save_checkpoint(
404 &self,
405 executor: &Scirs2Exec,
406 iteration: usize,
407 ) -> TlBackendResult<std::path::PathBuf> {
408 let checkpoint = Checkpoint::from_executor(executor, iteration)?;
409 let filename = self.filename_pattern.replace("{}", &iteration.to_string());
410 let path = self.checkpoint_dir.join(filename);
411
412 checkpoint.save(&path)?;
413
414 if let Some(max) = self.max_checkpoints {
416 self.cleanup_old_checkpoints(max)?;
417 }
418
419 Ok(path)
420 }
421
422 pub fn load_latest(&self) -> TlBackendResult<Checkpoint> {
424 let latest_path = self.find_latest_checkpoint()?;
425 Checkpoint::load(latest_path)
426 }
427
428 fn find_latest_checkpoint(&self) -> TlBackendResult<std::path::PathBuf> {
430 let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
431 TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
432 })?;
433
434 let mut checkpoints: Vec<_> = entries
435 .filter_map(|e| e.ok())
436 .filter(|e| {
437 e.path()
438 .extension()
439 .and_then(|s| s.to_str())
440 .map(|s| s == "json")
441 .unwrap_or(false)
442 })
443 .collect();
444
445 checkpoints.sort_by_key(|e| {
446 e.metadata()
447 .ok()
448 .and_then(|m| m.modified().ok())
449 .unwrap_or(SystemTime::UNIX_EPOCH)
450 });
451
452 checkpoints
453 .last()
454 .map(|e| e.path())
455 .ok_or_else(|| TlBackendError::execution("No checkpoints found"))
456 }
457
458 fn cleanup_old_checkpoints(&self, max: usize) -> TlBackendResult<()> {
460 let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
461 TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
462 })?;
463
464 let mut checkpoints: Vec<_> = entries
465 .filter_map(|e| e.ok())
466 .filter(|e| {
467 e.path()
468 .extension()
469 .and_then(|s| s.to_str())
470 .map(|s| s == "json")
471 .unwrap_or(false)
472 })
473 .collect();
474
475 checkpoints.sort_by_key(|e| {
476 e.metadata()
477 .ok()
478 .and_then(|m| m.modified().ok())
479 .unwrap_or(SystemTime::UNIX_EPOCH)
480 });
481
482 let to_remove = checkpoints.len().saturating_sub(max);
484 for entry in checkpoints.iter().take(to_remove) {
485 std::fs::remove_file(entry.path()).ok();
486 }
487
488 Ok(())
489 }
490
491 pub fn list_checkpoints(&self) -> TlBackendResult<Vec<std::path::PathBuf>> {
493 let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
494 TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
495 })?;
496
497 let mut checkpoints: Vec<_> = entries
498 .filter_map(|e| e.ok())
499 .filter(|e| {
500 e.path()
501 .extension()
502 .and_then(|s| s.to_str())
503 .map(|s| s == "json")
504 .unwrap_or(false)
505 })
506 .map(|e| e.path())
507 .collect();
508
509 checkpoints.sort();
510 Ok(checkpoints)
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Returns a path under the system temp dir that is unique to this test process, so
    /// concurrent `cargo test` invocations do not trample each other's files.
    fn unique_temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("tl_checkpoint_{}_{}", std::process::id(), name))
    }

    /// Like `unique_temp_path`, but first removes anything left behind by an earlier
    /// (possibly failed) run that reused the same process id.
    fn fresh_temp_dir(name: &str) -> std::path::PathBuf {
        let dir = unique_temp_path(name);
        std::fs::remove_dir_all(&dir).ok();
        dir
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(!config.enable_compression);
        assert!(!config.include_tape);
        assert!(config.verify_checksum);
        assert!(!config.incremental);
    }

    #[test]
    fn test_checkpoint_config_training() {
        let config = CheckpointConfig::for_training();
        assert!(!config.enable_compression);
        assert!(config.include_tape);
        assert!(config.verify_checksum);
    }

    #[test]
    fn test_checkpoint_config_inference() {
        let config = CheckpointConfig::for_inference();
        assert!(config.enable_compression);
        assert!(!config.include_tape);
        assert!(config.verify_checksum);
    }

    #[test]
    fn test_checkpoint_from_executor() {
        let mut executor = Scirs2Exec::new();
        let tensor =
            ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        executor.add_tensor("test_tensor", tensor);

        let checkpoint = Checkpoint::from_executor(&executor, 1).unwrap();

        assert_eq!(checkpoint.metadata.iteration, 1);
        assert_eq!(checkpoint.metadata.tensor_count, 1);
        assert_eq!(checkpoint.metadata.total_bytes, 6 * std::mem::size_of::<f64>());
        assert_eq!(checkpoint.size_bytes(), checkpoint.metadata.total_bytes);
    }

    #[test]
    fn test_checkpoint_save_and_load() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
        executor.add_tensor("weights", tensor);

        let checkpoint = Checkpoint::from_executor(&executor, 5).unwrap();
        let temp_path = unique_temp_path("test_checkpoint.json");
        checkpoint.save(&temp_path).unwrap();

        let loaded = Checkpoint::load(&temp_path);
        std::fs::remove_file(&temp_path).ok();
        let loaded = loaded.unwrap();

        assert_eq!(loaded.metadata.iteration, 5);
        assert_eq!(loaded.metadata.tensor_count, 1);
    }

    #[test]
    fn test_checkpoint_restore() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![10.0, 20.0]).unwrap();
        executor.add_tensor("params", tensor.clone());

        let checkpoint = Checkpoint::from_executor(&executor, 1).unwrap();
        let restored_executor = checkpoint.restore().unwrap();

        let restored_tensor = restored_executor.get_tensor("params").unwrap();
        assert_eq!(restored_tensor.shape(), tensor.shape());
        assert_eq!(restored_tensor[[0]], 10.0);
        assert_eq!(restored_tensor[[1]], 20.0);
    }

    #[test]
    fn test_checkpoint_restore_into_existing_executor() {
        let mut source = Scirs2Exec::new();
        source.add_tensor(
            "restored",
            ArrayD::from_shape_vec(vec![2], vec![3.0, 4.0]).unwrap(),
        );
        let checkpoint = Checkpoint::from_executor(&source, 2).unwrap();

        let mut target = Scirs2Exec::new();
        target.add_tensor("existing", ArrayD::from_shape_vec(vec![1], vec![9.0]).unwrap());
        checkpoint.restore_into(&mut target).unwrap();

        assert_eq!(target.get_tensor("existing").unwrap()[[0]], 9.0);
        assert_eq!(target.get_tensor("restored").unwrap()[[1]], 4.0);
    }

    #[test]
    fn test_checkpoint_metadata() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap();
        executor.add_tensor("x", tensor);

        let mut checkpoint = Checkpoint::from_executor(&executor, 10).unwrap();
        checkpoint.add_metadata("learning_rate".to_string(), "0.001".to_string());
        checkpoint.add_metadata("optimizer".to_string(), "adam".to_string());

        assert_eq!(
            checkpoint.get_metadata("learning_rate"),
            Some(&"0.001".to_string())
        );
        assert_eq!(
            checkpoint.get_metadata("optimizer"),
            Some(&"adam".to_string())
        );
        assert_eq!(checkpoint.get_metadata("missing"), None);
    }

    #[test]
    fn test_checkpoint_size_human_readable() {
        // One f64 = 8 bytes: below the KB threshold.
        let mut small = Scirs2Exec::new();
        small.add_tensor("one", ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap());
        let checkpoint = Checkpoint::from_executor(&small, 1).unwrap();
        assert_eq!(checkpoint.size_human_readable(), "8 bytes");

        // 128 f64 = 1024 bytes: exactly 1 KB.
        let mut kib = Scirs2Exec::new();
        kib.add_tensor("kib", ArrayD::from_shape_vec(vec![128], vec![1.0; 128]).unwrap());
        let checkpoint = Checkpoint::from_executor(&kib, 1).unwrap();
        assert_eq!(checkpoint.size_human_readable(), "1.00 KB");
    }

    #[test]
    fn test_checkpoint_manager() {
        let temp_dir = fresh_temp_dir("test_checkpoints");
        let manager = CheckpointManager::new(&temp_dir).unwrap();

        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
        executor.add_tensor("data", tensor);

        let path = manager.save_checkpoint(&executor, 1).unwrap();
        let path_exists = path.exists();
        let checkpoints = manager.list_checkpoints().unwrap();
        let latest = manager.load_latest();

        std::fs::remove_dir_all(&temp_dir).ok();

        assert!(path_exists);
        assert_eq!(checkpoints.len(), 1);
        assert_eq!(latest.unwrap().metadata.iteration, 1);
    }

    #[test]
    fn test_checkpoint_manager_cleanup() {
        let temp_dir = fresh_temp_dir("test_checkpoints_cleanup");
        let mut manager = CheckpointManager::new(&temp_dir).unwrap();
        manager.set_max_checkpoints(Some(3));

        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap();
        executor.add_tensor("x", tensor);

        for i in 1..=5 {
            manager.save_checkpoint(&executor, i).unwrap();
        }

        let checkpoints = manager.list_checkpoints().unwrap();
        std::fs::remove_dir_all(&temp_dir).ok();

        // Each save prunes back to the limit, so exactly three files survive.
        assert_eq!(checkpoints.len(), 3);
    }

    #[test]
    fn test_checkpoint_checksum_verification() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
        executor.add_tensor("data", tensor);

        let config = CheckpointConfig {
            verify_checksum: true,
            ..Default::default()
        };

        let checkpoint = Checkpoint::from_executor_with_config(&executor, 1, &config).unwrap();
        assert!(checkpoint.metadata.checksum.is_some());

        let temp_path = unique_temp_path("test_checksum.json");
        checkpoint.save(&temp_path).unwrap();

        let loaded = Checkpoint::load_with_config(&temp_path, &config);
        std::fs::remove_file(&temp_path).ok();

        assert_eq!(loaded.unwrap().metadata.checksum, checkpoint.metadata.checksum);
    }

    #[test]
    fn test_checkpoint_checksum_detects_tampering() {
        let mut executor = Scirs2Exec::new();
        let tensor = ArrayD::from_shape_vec(vec![2], vec![1.0, 42.5]).unwrap();
        executor.add_tensor("data", tensor);

        let checkpoint = Checkpoint::from_executor(&executor, 1).unwrap();
        let temp_path = unique_temp_path("test_checksum_tamper.json");
        checkpoint.save(&temp_path).unwrap();

        let contents = std::fs::read_to_string(&temp_path).unwrap();
        let had_value = contents.contains("42.5");
        std::fs::write(&temp_path, contents.replace("42.5", "43.5")).unwrap();

        let verified = Checkpoint::load(&temp_path);
        let unverified = Checkpoint::load_with_config(
            &temp_path,
            &CheckpointConfig {
                verify_checksum: false,
                ..Default::default()
            },
        );
        std::fs::remove_file(&temp_path).ok();

        assert!(had_value);
        assert!(verified.is_err());
        assert!(unverified.is_ok());
    }

    #[test]
    fn test_checkpoint_without_checksum() {
        let mut executor = Scirs2Exec::new();
        executor.add_tensor("x", ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap());

        let config = CheckpointConfig {
            verify_checksum: false,
            ..Default::default()
        };
        let checkpoint = Checkpoint::from_executor_with_config(&executor, 1, &config).unwrap();
        assert!(checkpoint.metadata.checksum.is_none());
    }
}