1use std::collections::HashMap;
7use std::path::Path;
8
9use yscv_tensor::Tensor;
10
11use super::weights::{load_weights, save_weights};
12use crate::ModelError;
13
/// Key prefix that marks an entry of a combined checkpoint map as optimizer state.
///
/// It is prepended to every optimizer key on save and stripped again on load, which is
/// how the two kinds of entries are told apart inside a single map.
const OPT_PREFIX: &str = "__opt__.";
16
17pub fn save_training_checkpoint(
22 path: &Path,
23 model_weights: &HashMap<String, Tensor>,
24 optimizer_state: &HashMap<String, Tensor>,
25) -> Result<(), ModelError> {
26 let mut combined = model_weights.clone();
27 for (key, tensor) in optimizer_state {
28 combined.insert(format!("{OPT_PREFIX}{key}"), tensor.clone());
29 }
30 save_weights(path, &combined)
31}
32
33pub fn load_training_checkpoint(
37 path: &Path,
38) -> Result<(HashMap<String, Tensor>, HashMap<String, Tensor>), ModelError> {
39 let all = load_weights(path)?;
40 let mut model_weights = HashMap::new();
41 let mut optimizer_state = HashMap::new();
42
43 for (key, tensor) in all {
44 if let Some(stripped) = key.strip_prefix(OPT_PREFIX) {
45 optimizer_state.insert(stripped.to_owned(), tensor);
46 } else {
47 model_weights.insert(key, tensor);
48 }
49 }
50
51 Ok((model_weights, optimizer_state))
52}
53
54pub fn sgd_state_to_map(velocity: &HashMap<u64, Tensor>) -> HashMap<String, Tensor> {
58 velocity
59 .iter()
60 .map(|(id, t)| (format!("sgd.{id}.velocity"), t.clone()))
61 .collect()
62}
63
64pub fn sgd_state_from_map(map: &HashMap<String, Tensor>) -> HashMap<u64, Tensor> {
66 let mut velocity = HashMap::new();
67 for (key, tensor) in map {
68 if let Some(rest) = key.strip_prefix("sgd.")
69 && let Some(id_str) = rest.strip_suffix(".velocity")
70 && let Ok(id) = id_str.parse::<u64>()
71 {
72 velocity.insert(id, tensor.clone());
73 }
74 }
75 velocity
76}
77
78pub fn adam_state_to_map(state: &[(u64, Tensor, Tensor, u64)]) -> HashMap<String, Tensor> {
82 let mut map = HashMap::new();
83 for (id, m, v, step) in state {
84 map.insert(format!("adam.{id}.m"), m.clone());
85 map.insert(format!("adam.{id}.v"), v.clone());
86 map.insert(
88 format!("adam.{id}.step"),
89 Tensor::from_vec(vec![1], vec![*step as f32]).expect("scalar shape matches data"),
90 );
91 }
92 map
93}
94
95pub fn adam_state_from_map(map: &HashMap<String, Tensor>) -> Vec<(u64, Tensor, Tensor, u64)> {
99 let mut ids: Vec<u64> = map
101 .keys()
102 .filter_map(|k| {
103 k.strip_prefix("adam.")
104 .and_then(|rest| rest.strip_suffix(".m"))
105 .and_then(|id_str| id_str.parse::<u64>().ok())
106 })
107 .collect();
108 ids.sort();
109 ids.dedup();
110
111 ids.into_iter()
112 .filter_map(|id| {
113 let m = map.get(&format!("adam.{id}.m"))?.clone();
114 let v = map.get(&format!("adam.{id}.v"))?.clone();
115 let step = map
116 .get(&format!("adam.{id}.step"))
117 .map(|t| t.data()[0] as u64)
118 .unwrap_or(0);
119 Some((id, m, v, step))
120 })
121 .collect()
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Saving and then loading a checkpoint returns the model weights and the optimizer
    /// state in their separate maps with unchanged values.
    #[test]
    fn test_save_load_training_checkpoint_roundtrip() {
        // Per-process directory so concurrent test runs on one machine never share a file.
        let dir = std::env::temp_dir().join(format!(
            "yscv_checkpoint_test_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).expect("temp directory should be creatable");
        let path = dir.join("checkpoint.bin");

        let mut model_weights = HashMap::new();
        model_weights.insert(
            "layer.0.weight".to_string(),
            Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );

        let mut opt_state = HashMap::new();
        opt_state.insert(
            "sgd.0.velocity".to_string(),
            Tensor::from_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]).unwrap(),
        );

        let saved = save_training_checkpoint(&path, &model_weights, &opt_state);
        let loaded = saved.and_then(|()| load_training_checkpoint(&path));

        // Clean up before asserting so a failing assertion does not leave files behind.
        let _ = std::fs::remove_dir_all(&dir);

        let (loaded_weights, loaded_opt) = loaded.unwrap();

        assert!(loaded_weights.contains_key("layer.0.weight"));
        assert!(loaded_opt.contains_key("sgd.0.velocity"));
        assert_eq!(
            loaded_weights["layer.0.weight"].data(),
            &[1.0, 2.0, 3.0, 4.0]
        );
        assert_eq!(loaded_opt["sgd.0.velocity"].data(), &[0.1, 0.2, 0.3, 0.4]);
    }

    /// SGD velocity buffers survive a trip through the string-keyed map.
    #[test]
    fn test_sgd_state_roundtrip() {
        let mut velocity = HashMap::new();
        velocity.insert(
            42u64,
            Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(),
        );

        let map = sgd_state_to_map(&velocity);
        let restored = sgd_state_from_map(&map);

        assert!(restored.contains_key(&42));
        assert_eq!(restored[&42].data(), &[1.0, 2.0, 3.0]);
    }

    /// Keys that do not follow the `sgd.{id}.velocity` layout are skipped.
    #[test]
    fn test_sgd_state_from_map_ignores_foreign_keys() {
        let mut map = HashMap::new();
        map.insert(
            "sgd.3.velocity".to_string(),
            Tensor::from_vec(vec![1], vec![5.0]).unwrap(),
        );
        map.insert(
            "sgd.not_a_number.velocity".to_string(),
            Tensor::from_vec(vec![1], vec![6.0]).unwrap(),
        );
        map.insert(
            "adam.3.m".to_string(),
            Tensor::from_vec(vec![1], vec![7.0]).unwrap(),
        );

        let restored = sgd_state_from_map(&map);

        assert_eq!(restored.len(), 1);
        assert_eq!(restored[&3].data(), &[5.0]);
    }

    /// Adam state (id, both tensors, and step count) survives a trip through the map.
    #[test]
    fn test_adam_state_roundtrip() {
        let state = vec![(
            7u64,
            Tensor::from_vec(vec![2], vec![0.1, 0.2]).unwrap(),
            Tensor::from_vec(vec![2], vec![0.01, 0.02]).unwrap(),
            100u64,
        )];

        let map = adam_state_to_map(&state);
        let restored = adam_state_from_map(&map);

        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].0, 7);
        assert_eq!(restored[0].1.data(), &[0.1, 0.2]);
        assert_eq!(restored[0].2.data(), &[0.01, 0.02]);
        assert_eq!(restored[0].3, 100);
    }

    /// A missing `adam.{id}.step` entry yields a step count of zero.
    #[test]
    fn test_adam_state_missing_step_defaults_to_zero() {
        let state = vec![(
            1u64,
            Tensor::from_vec(vec![1], vec![0.5]).unwrap(),
            Tensor::from_vec(vec![1], vec![0.25]).unwrap(),
            9u64,
        )];

        let mut map = adam_state_to_map(&state);
        map.remove("adam.1.step");
        let restored = adam_state_from_map(&map);

        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].0, 1);
        assert_eq!(restored[0].3, 0);
    }
}