1use crate::{FxGraph, TorshResult};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8use std::time::{SystemTime, UNIX_EPOCH};
9use torsh_core::error::TorshError;
10use torsh_tensor::Tensor;
11
/// Bookkeeping attached to every saved checkpoint: when and at which
/// training step it was taken, plus integrity and versioning info.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Wall-clock creation time, seconds since the Unix epoch.
    pub timestamp: u64,
    /// Training step the checkpoint was taken at.
    pub step: u64,
    /// Loss recorded at checkpoint time, if any.
    pub loss: Option<f64>,
    /// Free-form model description (e.g. "graph_checkpoint").
    pub model_info: String,
    /// Arbitrary user-supplied key/value annotations.
    pub user_metadata: HashMap<String, String>,
    /// Lowercase-hex MD5 digest of the payload; empty when not computed.
    pub checksum: String,
    /// Checkpoint format version (currently always 1).
    pub version: u32,
}
30
31impl CheckpointMetadata {
32 pub fn new(step: u64, model_info: String) -> Self {
34 let timestamp = SystemTime::now()
35 .duration_since(UNIX_EPOCH)
36 .unwrap_or_default()
37 .as_secs();
38
39 Self {
40 timestamp,
41 step,
42 loss: None,
43 model_info,
44 user_metadata: HashMap::new(),
45 checksum: String::new(),
46 version: 1,
47 }
48 }
49
50 pub fn with_loss(mut self, loss: f64) -> Self {
52 self.loss = Some(loss);
53 self
54 }
55
56 pub fn with_metadata(mut self, key: String, value: String) -> Self {
58 self.user_metadata.insert(key, value);
59 self
60 }
61
62 pub fn with_checksum(mut self, data: &[u8]) -> Self {
64 let hash = md5::compute(data);
65 self.checksum = format!("{hash:x}");
66 self
67 }
68
69 pub fn verify_checksum(&self, data: &[u8]) -> bool {
71 let hash = md5::compute(data);
72 let computed = format!("{hash:x}");
73 computed == self.checksum
74 }
75}
76
/// Everything needed to restore a training/execution session.
///
/// Not serde-derived — presumably because `FxGraph` lacks the serde
/// traits (confirm); persistence goes through `CheckpointManager`.
#[derive(Debug, Clone)]
pub struct CheckpointData {
    /// The traced computation graph.
    pub graph: FxGraph,
    /// Named tensor snapshots (parameters, buffers, inputs).
    pub tensor_states: HashMap<String, TensorState>,
    /// Per-optimizer state keyed by optimizer name.
    pub optimizer_states: HashMap<String, OptimizerState>,
    /// RNG engine states keyed by generator name.
    pub rng_states: HashMap<String, RngState>,
    /// Opaque user-defined blobs keyed by name.
    pub custom_states: HashMap<String, Vec<u8>>,
    /// Bookkeeping (step, timestamp, checksum, ...).
    pub metadata: CheckpointMetadata,
}
93
/// Serializable snapshot of a single tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorState {
    /// Dimension sizes.
    pub shape: Vec<usize>,
    /// Debug-formatted dtype name (see [`TensorState::from_tensor`]).
    pub dtype: String,
    /// Raw element bytes (currently a zero-filled placeholder).
    pub data: Vec<u8>,
    /// Device the tensor lived on (currently always "cpu").
    pub device_type: String,
    /// Whether gradients were tracked (currently always false).
    pub requires_grad: bool,
}
108
impl TensorState {
    /// Snapshot `tensor`'s shape and dtype into a serializable form.
    ///
    /// NOTE(review): the `data` buffer is zero-filled — the actual tensor
    /// contents are NOT captured yet, and `device_type`/`requires_grad`
    /// are hard-coded. Round-tripping values through a checkpoint will
    /// not preserve them until this is implemented.
    pub fn from_tensor(tensor: &Tensor) -> TorshResult<Self> {
        Ok(Self {
            shape: tensor.shape().dims().to_vec(),
            // Debug-formatted dtype name; readers must match this format.
            dtype: format!("{:?}", tensor.dtype()),
            // Placeholder buffer of the correct byte length
            // (element count * element size), all zeros.
            data: vec![0; tensor.shape().numel() * tensor.dtype().size()],
            device_type: "cpu".to_string(),
            requires_grad: false,
        })
    }

    /// Rebuild a tensor of the recorded shape.
    ///
    /// NOTE(review): returns a zero tensor — `data`, `dtype`,
    /// `device_type`, and `requires_grad` are currently ignored.
    pub fn to_tensor(&self) -> TorshResult<Tensor> {
        use torsh_tensor::creation::zeros;
        zeros(&self.shape)
    }
}
133
/// Serializable optimizer snapshot: type, hyper-parameters, and
/// per-parameter buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerState {
    /// Optimizer identifier, e.g. "adam".
    pub optimizer_type: String,
    /// Base learning rate.
    pub learning_rate: f64,
    /// Number of optimization steps taken so far.
    pub step_count: u64,
    /// Scalar hyper-parameters keyed by name (e.g. "beta1" -> 0.9).
    pub parameters: HashMap<String, f64>,
    /// Serialized per-parameter buffers keyed by parameter name.
    pub param_states: HashMap<String, Vec<u8>>,
}
148
149impl OptimizerState {
150 pub fn new(optimizer_type: String, learning_rate: f64) -> Self {
152 Self {
153 optimizer_type,
154 learning_rate,
155 step_count: 0,
156 parameters: HashMap::new(),
157 param_states: HashMap::new(),
158 }
159 }
160
161 pub fn with_parameter(mut self, name: String, value: f64) -> Self {
163 self.parameters.insert(name, value);
164 self
165 }
166
167 pub fn with_param_state(mut self, name: String, state: Vec<u8>) -> Self {
169 self.param_states.insert(name, state);
170 self
171 }
172}
173
/// Serializable random-number-generator state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RngState {
    /// Generator identifier (engine/algorithm name).
    pub rng_type: String,
    /// Opaque serialized engine state; empty until captured.
    pub state: Vec<u8>,
    /// Seed the generator was initialized with.
    pub seed: u64,
}
184
185impl RngState {
186 pub fn new(rng_type: String, seed: u64) -> Self {
188 Self {
189 rng_type,
190 state: vec![],
191 seed,
192 }
193 }
194
195 pub fn with_state(mut self, state: Vec<u8>) -> Self {
197 self.state = state;
198 self
199 }
200}
201
/// Tunables controlling how checkpoints are written and rotated.
#[derive(Debug, Clone)]
pub struct CheckpointOptions {
    /// Gzip-compress the serialized payload.
    pub compress: bool,
    /// Gzip level, 0 (none) to 9 (best); used when `compress` is true.
    pub compression_level: u32,
    /// Store tensors in separate files (not yet consulted in this module).
    pub separate_tensors: bool,
    /// Keep at most this many checkpoints; `None` keeps everything.
    pub max_history: Option<usize>,
    /// Maintain a `latest.ckpt` symlink/copy pointing at the newest save.
    pub create_latest_link: bool,
    /// On-disk serialization format (not yet consulted in this module).
    pub format: CheckpointFormat,
}
218
219impl Default for CheckpointOptions {
220 fn default() -> Self {
221 Self {
222 compress: true,
223 compression_level: 6,
224 separate_tensors: false,
225 max_history: Some(5),
226 create_latest_link: true,
227 format: CheckpointFormat::Binary,
228 }
229 }
230}
231
/// On-disk serialization formats a checkpoint may be written in.
/// NOTE(review): currently stored in options but not yet acted upon.
#[derive(Debug, Clone, Copy)]
pub enum CheckpointFormat {
    /// Binary encoding.
    Binary,
    /// JSON encoding.
    Json,
    /// Torsh-native encoding.
    Torsh,
}
242
/// Writes, loads, and rotates checkpoints inside a single directory.
pub struct CheckpointManager {
    /// Directory all managed checkpoint files live in.
    checkpoint_dir: PathBuf,
    /// Save/rotation behavior knobs.
    options: CheckpointOptions,
    /// Known checkpoint paths, oldest first (sorted by modification time).
    history: Vec<PathBuf>,
}
252
253impl CheckpointManager {
254 pub fn new<P: AsRef<Path>>(checkpoint_dir: P, options: CheckpointOptions) -> TorshResult<Self> {
256 let checkpoint_dir = checkpoint_dir.as_ref().to_path_buf();
257
258 if !checkpoint_dir.exists() {
260 fs::create_dir_all(&checkpoint_dir).map_err(|e| {
261 TorshError::InvalidArgument(format!("Failed to create checkpoint directory: {e}"))
262 })?;
263 }
264
265 let mut manager = Self {
266 checkpoint_dir,
267 options,
268 history: vec![],
269 };
270
271 manager.load_history()?;
273
274 Ok(manager)
275 }
276
277 pub fn save_checkpoint(
279 &mut self,
280 data: CheckpointData,
281 name: Option<String>,
282 ) -> TorshResult<PathBuf> {
283 let filename = name.unwrap_or_else(|| {
284 let step = data.metadata.step;
285 format!("checkpoint_step_{step}.ckpt")
286 });
287
288 let checkpoint_path = self.checkpoint_dir.join(&filename);
289
290 let step = data.metadata.step;
293 let serialized = format!("checkpoint_placeholder_step_{step}").into_bytes();
294
295 let final_data = if self.options.compress {
297 self.compress_data(&serialized)?
298 } else {
299 serialized
300 };
301
302 fs::write(&checkpoint_path, &final_data)
304 .map_err(|e| TorshError::InvalidArgument(format!("Failed to write checkpoint: {e}")))?;
305
306 self.history.push(checkpoint_path.clone());
308 self.cleanup_old_checkpoints()?;
309
310 if self.options.create_latest_link {
312 self.create_latest_link(&checkpoint_path)?;
313 }
314
315 Ok(checkpoint_path)
316 }
317
318 pub fn load_checkpoint<P: AsRef<Path>>(&self, path: P) -> TorshResult<CheckpointData> {
320 let path = path.as_ref();
321
322 let file_data = fs::read(path)
324 .map_err(|e| TorshError::InvalidArgument(format!("Failed to read checkpoint: {e}")))?;
325
326 let data = if self.options.compress {
328 self.decompress_data(&file_data)?
329 } else {
330 file_data
331 };
332
333 let checkpoint = CheckpointData {
336 graph: crate::FxGraph::new(), tensor_states: HashMap::new(),
338 optimizer_states: HashMap::new(),
339 rng_states: HashMap::new(),
340 custom_states: HashMap::new(),
341 metadata: CheckpointMetadata::new(0, "placeholder".to_string()),
342 };
343
344 if !checkpoint.metadata.checksum.is_empty() && !checkpoint.metadata.verify_checksum(&data) {
346 return Err(TorshError::InvalidArgument(
347 "Checkpoint checksum verification failed".to_string(),
348 ));
349 }
350
351 Ok(checkpoint)
352 }
353
354 pub fn load_latest_checkpoint(&self) -> TorshResult<Option<CheckpointData>> {
356 let latest_path = self.checkpoint_dir.join("latest.ckpt");
357
358 if latest_path.exists() {
359 Ok(Some(self.load_checkpoint(latest_path)?))
360 } else if let Some(latest_from_history) = self.history.last() {
361 Ok(Some(self.load_checkpoint(latest_from_history)?))
362 } else {
363 Ok(None)
364 }
365 }
366
367 pub fn list_checkpoints(&self) -> Vec<PathBuf> {
369 self.history.clone()
370 }
371
372 pub fn delete_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> TorshResult<()> {
374 let path = path.as_ref();
375
376 fs::remove_file(path).map_err(|e| {
377 TorshError::InvalidArgument(format!("Failed to delete checkpoint: {e}"))
378 })?;
379
380 self.history.retain(|p| p != path);
382
383 Ok(())
384 }
385
386 pub fn get_checkpoint_metadata<P: AsRef<Path>>(
388 &self,
389 path: P,
390 ) -> TorshResult<CheckpointMetadata> {
391 let checkpoint = self.load_checkpoint(path)?;
394 Ok(checkpoint.metadata)
395 }
396
397 fn compress_data(&self, data: &[u8]) -> TorshResult<Vec<u8>> {
399 use flate2::write::GzEncoder;
400 use flate2::Compression;
401 use std::io::Write;
402
403 let mut encoder =
404 GzEncoder::new(Vec::new(), Compression::new(self.options.compression_level));
405 encoder
406 .write_all(data)
407 .map_err(|e| TorshError::InvalidArgument(format!("Compression failed: {e}")))?;
408
409 encoder
410 .finish()
411 .map_err(|e| TorshError::InvalidArgument(format!("Compression failed: {e}")))
412 }
413
414 fn decompress_data(&self, data: &[u8]) -> TorshResult<Vec<u8>> {
416 use flate2::read::GzDecoder;
417 use std::io::Read;
418
419 let mut decoder = GzDecoder::new(data);
420 let mut decompressed = Vec::new();
421 decoder
422 .read_to_end(&mut decompressed)
423 .map_err(|e| TorshError::InvalidArgument(format!("Decompression failed: {e}")))?;
424
425 Ok(decompressed)
426 }
427
428 fn load_history(&mut self) -> TorshResult<()> {
430 let entries = fs::read_dir(&self.checkpoint_dir).map_err(|e| {
431 TorshError::InvalidArgument(format!("Failed to read checkpoint directory: {e}"))
432 })?;
433
434 let mut checkpoints = Vec::new();
435 for entry in entries {
436 let entry = entry.map_err(|e| {
437 TorshError::InvalidArgument(format!("Failed to read directory entry: {e}"))
438 })?;
439
440 let path = entry.path();
441 if path.is_file() && path.extension().is_some_and(|ext| ext == "ckpt") {
442 checkpoints.push(path);
443 }
444 }
445
446 checkpoints.sort_by_key(|path| {
448 fs::metadata(path)
449 .and_then(|meta| meta.modified())
450 .unwrap_or(SystemTime::UNIX_EPOCH)
451 });
452
453 self.history = checkpoints;
454 Ok(())
455 }
456
457 fn cleanup_old_checkpoints(&mut self) -> TorshResult<()> {
459 if let Some(max_history) = self.options.max_history {
460 while self.history.len() > max_history {
461 let old_checkpoint = self.history.remove(0);
462 let _ = fs::remove_file(&old_checkpoint);
463 }
464 }
465 Ok(())
466 }
467
468 fn create_latest_link(&self, checkpoint_path: &Path) -> TorshResult<()> {
470 let latest_path = self.checkpoint_dir.join("latest.ckpt");
471
472 if latest_path.exists() {
474 let _ = fs::remove_file(&latest_path);
475 }
476
477 #[cfg(unix)]
479 {
480 std::os::unix::fs::symlink(checkpoint_path, &latest_path).map_err(|e| {
481 TorshError::InvalidArgument(format!("Failed to create symlink: {e}"))
482 })?;
483 }
484
485 #[cfg(windows)]
486 {
487 fs::copy(checkpoint_path, &latest_path).map_err(|e| {
488 TorshError::InvalidArgument(format!("Failed to copy checkpoint: {e}"))
489 })?;
490 }
491
492 Ok(())
493 }
494}
495
/// Snapshot of an in-flight graph execution, intended to let a run resume
/// mid-graph instead of restarting from scratch.
#[derive(Debug, Clone)]
pub struct ExecutionCheckpoint {
    /// Graph being executed.
    pub graph: FxGraph,
    /// Progress bookkeeping (current/completed/failed nodes, timing).
    pub execution_state: ExecutionState,
    /// Original graph inputs, snapshotted as serializable tensor states.
    pub inputs: HashMap<String, TensorState>,
    /// Outputs of already-executed nodes, keyed by name.
    pub intermediate_results: HashMap<String, TensorState>,
    /// Names of nodes not yet executed.
    pub remaining_nodes: Vec<String>,
    /// Standard checkpoint bookkeeping.
    pub metadata: CheckpointMetadata,
}
512
/// Progress bookkeeping for a (possibly interrupted) graph execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionState {
    /// Node currently executing, if any.
    pub current_node: Option<String>,
    /// Nodes that finished successfully.
    pub completed_nodes: Vec<String>,
    /// Nodes that failed.
    pub failed_nodes: Vec<String>,
    /// Execution start, seconds since the Unix epoch.
    pub start_time: u64,
    /// Seconds elapsed so far.
    pub elapsed_time: u64,
}
527
/// Graph interpreter that can persist execution checkpoints and resume
/// from the latest saved one.
pub struct ResumableInterpreter {
    /// Underlying interpreter that actually runs the graph.
    interpreter: crate::interpreter::GraphInterpreter,
    /// Where checkpoints are persisted; `None` disables checkpointing.
    checkpoint_manager: Option<CheckpointManager>,
    /// Most recent execution snapshot, if any.
    current_checkpoint: Option<ExecutionCheckpoint>,
    /// Checkpoint cadence; not yet consulted during execution.
    checkpoint_frequency: usize,
}
539
impl ResumableInterpreter {
    /// New interpreter for `device_type`; checkpointing stays off until
    /// [`ResumableInterpreter::with_checkpointing`] is called.
    pub fn new(device_type: torsh_core::device::DeviceType) -> Self {
        Self {
            interpreter: crate::interpreter::GraphInterpreter::new(device_type),
            checkpoint_manager: None,
            current_checkpoint: None,
            // Default cadence; override via `with_checkpoint_frequency`.
            checkpoint_frequency: 100,
        }
    }

    /// Builder-style setter enabling persistence through `manager`.
    pub fn with_checkpointing(mut self, manager: CheckpointManager) -> Self {
        self.checkpoint_manager = Some(manager);
        self
    }

    /// Builder-style setter for the checkpoint cadence.
    ///
    /// NOTE(review): the value is stored but never read during execution
    /// yet — confirm intended semantics once periodic saving is wired up.
    pub fn with_checkpoint_frequency(mut self, frequency: usize) -> Self {
        self.checkpoint_frequency = frequency;
        self
    }

    /// Run `graph`, resuming from the latest saved checkpoint when one can
    /// be extracted; otherwise starts a fresh execution.
    ///
    /// NOTE(review): `extract_execution_checkpoint` currently always
    /// errors, so in practice every call takes the fresh-execution path.
    pub fn run_with_checkpointing(
        &mut self,
        graph: &FxGraph,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        if let Some(manager) = &self.checkpoint_manager {
            // Probe failures are deliberately swallowed: fall back to a
            // fresh run rather than refusing to execute.
            if let Ok(Some(checkpoint_data)) = manager.load_latest_checkpoint() {
                if let Ok(execution_checkpoint) =
                    self.extract_execution_checkpoint(&checkpoint_data)
                {
                    return self.resume_execution(execution_checkpoint);
                }
            }
        }

        self.start_fresh_execution(graph, inputs)
    }

    /// Build an initial `ExecutionCheckpoint` for `graph`/`inputs`, record
    /// it as current, and run.
    fn start_fresh_execution(
        &mut self,
        graph: &FxGraph,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        // Wall-clock start; a pre-epoch clock degrades to 0 rather than
        // panicking.
        let start_time = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        // Snapshot the inputs so a later resume can rehydrate them.
        let mut tensor_states = HashMap::new();
        for (name, tensor) in &inputs {
            tensor_states.insert(name.clone(), TensorState::from_tensor(tensor)?);
        }

        let execution_state = ExecutionState {
            current_node: None,
            completed_nodes: vec![],
            failed_nodes: vec![],
            start_time,
            elapsed_time: 0,
        };

        let checkpoint = ExecutionCheckpoint {
            graph: graph.clone(),
            execution_state,
            inputs: tensor_states,
            intermediate_results: HashMap::new(),
            // Every node is still pending; ids are debug-formatted.
            remaining_nodes: graph.nodes().map(|(idx, _)| format!("{idx:?}")).collect(),
            metadata: CheckpointMetadata::new(0, "execution_checkpoint".to_string()),
        };

        self.current_checkpoint = Some(checkpoint);

        self.execute_with_checkpoints(inputs)
    }

    /// Install `checkpoint` as current, rebuild its inputs as runtime
    /// tensors, and continue execution.
    fn resume_execution(&mut self, checkpoint: ExecutionCheckpoint) -> TorshResult<Vec<Tensor>> {
        self.current_checkpoint = Some(checkpoint);

        // Rehydrate runtime tensors from the serialized input snapshots.
        let mut inputs = HashMap::new();
        if let Some(ref checkpoint) = self.current_checkpoint {
            for (name, tensor_state) in &checkpoint.inputs {
                inputs.insert(name.clone(), tensor_state.to_tensor()?);
            }
        }

        self.execute_with_checkpoints(inputs)
    }

    /// Run the graph held by `current_checkpoint`.
    ///
    /// # Panics
    /// Panics if `current_checkpoint` is unset — an internal invariant
    /// guaranteed by the two callers above.
    fn execute_with_checkpoints(
        &mut self,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        self.interpreter.run(
            &self
                .current_checkpoint
                .as_ref()
                .expect("checkpoint should be set before execution")
                .graph,
            inputs,
        )
    }

    /// Decode an `ExecutionCheckpoint` out of generic checkpoint data.
    ///
    /// NOTE(review): not implemented — always errors, which forces
    /// `run_with_checkpointing` to start fresh every time.
    fn extract_execution_checkpoint(
        &self,
        _data: &CheckpointData,
    ) -> TorshResult<ExecutionCheckpoint> {
        Err(TorshError::InvalidArgument(
            "No execution checkpoint found".to_string(),
        ))
    }

    /// Persist the current execution snapshot (if both a manager and a
    /// snapshot exist) under the fixed name `execution.ckpt`.
    pub fn save_execution_checkpoint(&mut self) -> TorshResult<()> {
        if let (Some(manager), Some(checkpoint)) =
            (&mut self.checkpoint_manager, &self.current_checkpoint)
        {
            let checkpoint_data = CheckpointData {
                graph: checkpoint.graph.clone(),
                // Execution-specific state (inputs/intermediates) is not
                // yet mapped into the generic checkpoint fields.
                tensor_states: HashMap::new(),
                optimizer_states: HashMap::new(),
                rng_states: HashMap::new(),
                custom_states: HashMap::new(),
                metadata: checkpoint.metadata.clone(),
            };

            manager.save_checkpoint(checkpoint_data, Some("execution.ckpt".to_string()))?;
        }

        Ok(())
    }
}
687
688pub fn create_checkpoint(
691 graph: &FxGraph,
692 tensors: HashMap<String, Tensor>,
693 step: u64,
694 loss: Option<f64>,
695) -> TorshResult<CheckpointData> {
696 let mut tensor_states = HashMap::new();
697 for (name, tensor) in tensors {
698 tensor_states.insert(name, TensorState::from_tensor(&tensor)?);
699 }
700
701 let mut metadata = CheckpointMetadata::new(step, "graph_checkpoint".to_string());
702 if let Some(loss_val) = loss {
703 metadata = metadata.with_loss(loss_val);
704 }
705
706 Ok(CheckpointData {
707 graph: graph.clone(),
708 tensor_states,
709 optimizer_states: HashMap::new(),
710 rng_states: HashMap::new(),
711 custom_states: HashMap::new(),
712 metadata,
713 })
714}
715
716pub fn save_checkpoint<P: AsRef<Path>>(
718 path: P,
719 data: CheckpointData,
720 options: Option<CheckpointOptions>,
721) -> TorshResult<()> {
722 let options = options.unwrap_or_default();
723 let mut manager =
724 CheckpointManager::new(path.as_ref().parent().unwrap_or(Path::new(".")), options)?;
725
726 let filename = path
727 .as_ref()
728 .file_name()
729 .and_then(|name| name.to_str())
730 .unwrap_or("checkpoint.ckpt")
731 .to_string();
732
733 manager.save_checkpoint(data, Some(filename))?;
734 Ok(())
735}
736
737pub fn load_checkpoint<P: AsRef<Path>>(
739 path: P,
740 options: Option<CheckpointOptions>,
741) -> TorshResult<CheckpointData> {
742 let options = options.unwrap_or_default();
743 let manager =
744 CheckpointManager::new(path.as_ref().parent().unwrap_or(Path::new(".")), options)?;
745
746 manager.load_checkpoint(path)
747}
748
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tracer::ModuleTracer;
    use tempfile::TempDir;
    use torsh_tensor::creation::ones;

    // Builder methods populate step, loss, and user metadata.
    #[test]
    fn test_checkpoint_metadata() {
        let metadata = CheckpointMetadata::new(100, "test_model".to_string())
            .with_loss(0.5)
            .with_metadata("epoch".to_string(), "10".to_string());

        assert_eq!(metadata.step, 100);
        assert_eq!(metadata.loss, Some(0.5));
        assert_eq!(metadata.user_metadata.get("epoch"), Some(&"10".to_string()));
    }

    // Round-trips shape and dtype through TensorState (values are not
    // preserved — tensor data capture is still a placeholder).
    #[test]
    fn test_tensor_state_serialization() {
        let tensor = ones(&[2, 3]).unwrap();
        let state = TensorState::from_tensor(&tensor).unwrap();

        assert_eq!(state.shape, vec![2, 3]);
        assert_eq!(state.dtype, format!("{:?}", tensor.dtype()));

        let restored = state.to_tensor().unwrap();
        assert_eq!(restored.shape().dims(), &[2, 3]);
    }

    // Builder methods populate optimizer hyper-parameters.
    #[test]
    fn test_optimizer_state() {
        let state = OptimizerState::new("adam".to_string(), 0.001)
            .with_parameter("beta1".to_string(), 0.9)
            .with_parameter("beta2".to_string(), 0.999);

        assert_eq!(state.optimizer_type, "adam");
        assert_eq!(state.learning_rate, 0.001);
        assert_eq!(state.parameters.get("beta1"), Some(&0.9));
    }

    // Manager construction succeeds on an existing, empty directory.
    #[test]
    fn test_checkpoint_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let options = CheckpointOptions::default();

        let result = CheckpointManager::new(temp_dir.path(), options);
        assert!(result.is_ok());
    }

    // Save writes a file; load succeeds. Loaded metadata is defaulted
    // because deserialization is still a placeholder.
    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let options = CheckpointOptions::default();
        let mut manager = CheckpointManager::new(temp_dir.path(), options).unwrap();

        // Minimal traced graph: x -> relu -> output.
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("relu", vec!["x".to_string()]);
        tracer.add_output("node_0");
        let graph = tracer.finalize();

        let tensor = ones(&[2, 3]).unwrap();
        let checkpoint = create_checkpoint(
            &graph,
            vec![("x".to_string(), tensor)].into_iter().collect(),
            100,
            Some(0.5),
        )
        .unwrap();

        let saved_path = manager.save_checkpoint(checkpoint.clone(), None).unwrap();
        assert!(saved_path.exists());

        let loaded = manager.load_checkpoint(&saved_path).unwrap();
        // Placeholder loading: step/loss do not round-trip yet.
        assert!(loaded.metadata.step == 0);
        assert!(loaded.metadata.loss.is_none());
    }

    // Gzip round-trip is lossless and actually shrinks repetitive input.
    #[test]
    fn test_checkpoint_compression() {
        let temp_dir = TempDir::new().unwrap();
        let options = CheckpointOptions {
            compress: true,
            ..Default::default()
        };
        let manager = CheckpointManager::new(temp_dir.path(), options).unwrap();

        // Highly repetitive payload so compression must reduce its size.
        let test_data = vec![1u8; 1000];
        let compressed = manager.compress_data(&test_data).unwrap();
        let decompressed = manager.decompress_data(&compressed).unwrap();

        assert_eq!(test_data, decompressed);
        assert!(compressed.len() < test_data.len());
    }

    // Constructor applies the default checkpoint frequency.
    #[test]
    fn test_resumable_interpreter() {
        let interpreter = ResumableInterpreter::new(torsh_core::device::DeviceType::Cpu);

        assert_eq!(interpreter.checkpoint_frequency, 100);
    }

    // With max_history = 2, saving three checkpoints prunes the oldest.
    #[test]
    fn test_checkpoint_history_management() {
        let temp_dir = TempDir::new().unwrap();
        let options = CheckpointOptions {
            max_history: Some(2),
            ..Default::default()
        };
        let mut manager = CheckpointManager::new(temp_dir.path(), options).unwrap();

        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        let graph = tracer.finalize();

        let checkpoint = CheckpointData {
            graph,
            tensor_states: HashMap::new(),
            optimizer_states: HashMap::new(),
            rng_states: HashMap::new(),
            custom_states: HashMap::new(),
            metadata: CheckpointMetadata::new(0, "test".to_string()),
        };

        manager
            .save_checkpoint(checkpoint.clone(), Some("ckpt1.ckpt".to_string()))
            .unwrap();
        manager
            .save_checkpoint(checkpoint.clone(), Some("ckpt2.ckpt".to_string()))
            .unwrap();
        manager
            .save_checkpoint(checkpoint.clone(), Some("ckpt3.ckpt".to_string()))
            .unwrap();

        let history = manager.list_checkpoints();
        assert!(history.len() <= 2);
    }

    // Save/load works under each (currently cosmetic) format option.
    #[test]
    fn test_checkpoint_formats() {
        let temp_dir = TempDir::new().unwrap();

        for format in &[CheckpointFormat::Binary, CheckpointFormat::Json] {
            let options = CheckpointOptions {
                format: *format,
                compress: false,
                ..Default::default()
            };

            let mut manager = CheckpointManager::new(temp_dir.path(), options).unwrap();

            let mut tracer = ModuleTracer::new();
            tracer.add_input("x");
            let graph = tracer.finalize();

            let checkpoint = CheckpointData {
                graph,
                tensor_states: HashMap::new(),
                optimizer_states: HashMap::new(),
                rng_states: HashMap::new(),
                custom_states: HashMap::new(),
                metadata: CheckpointMetadata::new(0, "test".to_string()),
            };

            let saved_path = manager.save_checkpoint(checkpoint.clone(), None).unwrap();
            let loaded = manager.load_checkpoint(&saved_path).unwrap();

            assert_eq!(loaded.metadata.step, checkpoint.metadata.step);
        }
    }

    // ExecutionCheckpoint can be assembled from its parts.
    #[test]
    fn test_execution_checkpoint() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("relu", vec!["x".to_string()]);
        let graph = tracer.finalize();

        let execution_state = ExecutionState {
            current_node: None,
            completed_nodes: vec![],
            failed_nodes: vec![],
            start_time: 0,
            elapsed_time: 0,
        };

        let checkpoint = ExecutionCheckpoint {
            graph,
            execution_state,
            inputs: HashMap::new(),
            intermediate_results: HashMap::new(),
            remaining_nodes: vec![],
            metadata: CheckpointMetadata::new(0, "execution".to_string()),
        };

        assert_eq!(checkpoint.metadata.step, 0);
        assert_eq!(checkpoint.metadata.model_info, "execution");
    }

    // create_checkpoint wires step/loss into metadata and snapshots tensors.
    #[test]
    fn test_utility_functions() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        let graph = tracer.finalize();

        let tensor = ones(&[2, 3]).unwrap();
        let tensors = vec![("x".to_string(), tensor)].into_iter().collect();

        let checkpoint = create_checkpoint(&graph, tensors, 50, Some(0.25)).unwrap();

        assert_eq!(checkpoint.metadata.step, 50);
        assert_eq!(checkpoint.metadata.loss, Some(0.25));
        assert!(checkpoint.tensor_states.contains_key("x"));
    }
}