1use std::fs;
31use train_station::{
32 optimizers::{Adam, AdamConfig, Optimizer},
33 serialization::StructSerializable,
34 Tensor,
35};
36
37fn main() -> Result<(), Box<dyn std::error::Error>> {
38 println!("=== Serialization Basics Example ===\n");
39
40 demonstrate_tensor_serialization()?;
41 demonstrate_optimizer_serialization()?;
42 demonstrate_format_comparison()?;
43 demonstrate_model_checkpointing()?;
44 demonstrate_error_handling()?;
45 cleanup_temp_files()?;
46
47 println!("\n=== Example completed successfully! ===");
48 Ok(())
49}
50
51fn demonstrate_tensor_serialization() -> Result<(), Box<dyn std::error::Error>> {
53 println!("--- Tensor Serialization ---");
54
55 let original_tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
57 println!(
58 "Original tensor: shape {:?}, data: {:?}",
59 original_tensor.shape().dims(),
60 original_tensor.data()
61 );
62
63 let json_path = "temp_tensor.json";
65 original_tensor.save_json(json_path)?;
66 println!("Saved tensor to JSON: {}", json_path);
67
68 let loaded_tensor_json = Tensor::load_json(json_path)?;
70 println!(
71 "Loaded from JSON: shape {:?}, data: {:?}",
72 loaded_tensor_json.shape().dims(),
73 loaded_tensor_json.data()
74 );
75
76 assert_eq!(
78 original_tensor.shape().dims(),
79 loaded_tensor_json.shape().dims()
80 );
81 assert_eq!(original_tensor.data(), loaded_tensor_json.data());
82 println!("JSON serialization verification: PASSED");
83
84 let binary_path = "temp_tensor.bin";
86 original_tensor.save_binary(binary_path)?;
87 println!("Saved tensor to binary: {}", binary_path);
88
89 let loaded_tensor_binary = Tensor::load_binary(binary_path)?;
91 println!(
92 "Loaded from binary: shape {:?}, data: {:?}",
93 loaded_tensor_binary.shape().dims(),
94 loaded_tensor_binary.data()
95 );
96
97 assert_eq!(
99 original_tensor.shape().dims(),
100 loaded_tensor_binary.shape().dims()
101 );
102 assert_eq!(original_tensor.data(), loaded_tensor_binary.data());
103 println!("Binary serialization verification: PASSED");
104
105 Ok(())
106}
107
108fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110 println!("\n--- Optimizer Serialization ---");
111
112 let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114 let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116 let config = AdamConfig {
117 learning_rate: 0.001,
118 beta1: 0.9,
119 beta2: 0.999,
120 eps: 1e-8,
121 weight_decay: 0.0,
122 amsgrad: false,
123 };
124
125 let mut optimizer = Adam::with_config(config);
126 optimizer.add_parameter(&weight);
127 optimizer.add_parameter(&bias);
128
129 println!(
130 "Created optimizer with {} parameters",
131 optimizer.parameter_count()
132 );
133 println!("Learning rate: {}", optimizer.learning_rate());
134
135 for _ in 0..3 {
137 let mut loss = weight.sum() + bias.sum();
138 loss.backward(None);
139 optimizer.step(&mut [&mut weight, &mut bias]);
140 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141 }
142
143 let optimizer_path = "temp_optimizer.json";
145 optimizer.save_json(optimizer_path)?;
146 println!("Saved optimizer to: {}", optimizer_path);
147
148 let loaded_optimizer = Adam::load_json(optimizer_path)?;
150 println!(
151 "Loaded optimizer with {} parameters",
152 loaded_optimizer.parameter_count()
153 );
154 println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156 assert_eq!(
158 optimizer.parameter_count(),
159 loaded_optimizer.parameter_count()
160 );
161 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162 println!("Optimizer serialization verification: PASSED");
163
164 Ok(())
165}
166
167fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169 println!("\n--- Format Comparison ---");
170
171 let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174 tensor.save_json("temp_comparison.json")?;
176 tensor.save_binary("temp_comparison.bin")?;
177
178 let json_size = fs::metadata("temp_comparison.json")?.len();
180 let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182 println!("JSON file size: {} bytes", json_size);
183 println!("Binary file size: {} bytes", binary_size);
184 println!(
185 "Compression ratio: {:.2}x",
186 json_size as f64 / binary_size as f64
187 );
188
189 let json_tensor = Tensor::load_json("temp_comparison.json")?;
191 let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193 assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194 assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195 assert_eq!(tensor.data(), json_tensor.data());
196 assert_eq!(tensor.data(), binary_tensor.data());
197
198 println!("Format comparison verification: PASSED");
199
200 Ok(())
201}
202
203fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205 println!("\n--- Model Checkpointing ---");
206
207 let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209 let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211 let mut optimizer = Adam::with_learning_rate(0.01);
213 optimizer.add_parameter(&weights);
214 optimizer.add_parameter(&bias);
215
216 println!("Initial weights: {:?}", weights.data());
217 println!("Initial bias: {:?}", bias.data());
218
219 for epoch in 0..5 {
221 let mut loss = weights.sum() + bias.sum();
222 loss.backward(None);
223 optimizer.step(&mut [&mut weights, &mut bias]);
224 optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226 if epoch % 2 == 0 {
227 let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229 fs::create_dir_all(&checkpoint_dir)?;
230
231 weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232 bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233 optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235 println!("Saved checkpoint for epoch {}", epoch);
236 }
237 }
238
239 let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241 let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242 let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244 println!("Loaded weights: {:?}", loaded_weights.data());
245 println!("Loaded bias: {:?}", loaded_bias.data());
246 println!(
247 "Loaded optimizer learning rate: {}",
248 loaded_optimizer.learning_rate()
249 );
250
251 assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253 assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256 println!("Checkpointing verification: PASSED");
257
258 Ok(())
259}
260
261fn demonstrate_error_handling() -> Result<(), Box<dyn std::error::Error>> {
263 println!("\n--- Error Handling ---");
264
265 match Tensor::load_json("nonexistent_file.json") {
267 Ok(_) => println!("Unexpected: Successfully loaded non-existent file"),
268 Err(e) => println!("Expected error loading non-existent file: {}", e),
269 }
270
271 let tensor = Tensor::randn(vec![2, 2], Some(47));
273 tensor.save_binary("temp_binary.bin")?;
274
275 match Tensor::load_json("temp_binary.bin") {
276 Ok(_) => println!("Unexpected: Successfully loaded binary as JSON"),
277 Err(e) => println!("Expected error loading binary as JSON: {}", e),
278 }
279
280 fs::write("temp_invalid.json", "invalid json content")?;
282 match Tensor::load_json("temp_invalid.json") {
283 Ok(_) => println!("Unexpected: Successfully loaded invalid JSON"),
284 Err(e) => println!("Expected error loading invalid JSON: {}", e),
285 }
286
287 println!("Error handling verification: PASSED");
288
289 Ok(())
290}
291
/// Removes the temporary files and checkpoint directories created by the
/// demonstrations.
///
/// Attempts removal directly and treats `NotFound` as success instead of
/// probing with `fs::metadata` first — the check-then-remove pattern was a
/// TOCTOU race (the file could vanish between the check and the removal)
/// and cost an extra stat call per path. Any other I/O error (e.g.
/// permissions) is propagated to the caller.
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_tensor.json",
        "temp_tensor.bin",
        "temp_optimizer.json",
        "temp_comparison.json",
        "temp_comparison.bin",
        "temp_binary.bin",
        "temp_invalid.json",
    ];

    for file in &files_to_remove {
        match fs::remove_file(file) {
            Ok(()) => println!("Removed: {}", file),
            // Already gone (or never created) — nothing to clean up.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
    }

    // Checkpoints are written on even epochs only (0, 2, 4).
    for epoch in [0, 2, 4] {
        let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
        match fs::remove_dir_all(&checkpoint_dir) {
            Ok(()) => println!("Removed directory: {}", checkpoint_dir),
            // Directory absent — nothing to clean up.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
    }

    println!("Cleanup completed");
    Ok(())
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// A small tensor must survive a save/load cycle in both formats,
    /// preserving shape and contents exactly.
    #[test]
    fn test_tensor_serialization_roundtrip() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // JSON round-trip.
        source.save_json("test_tensor.json").unwrap();
        let restored = Tensor::load_json("test_tensor.json").unwrap();
        assert_eq!(source.shape().dims(), restored.shape().dims());
        assert_eq!(source.data(), restored.data());

        // Binary round-trip.
        source.save_binary("test_tensor.bin").unwrap();
        let restored = Tensor::load_binary("test_tensor.bin").unwrap();
        assert_eq!(source.shape().dims(), restored.shape().dims());
        assert_eq!(source.data(), restored.data());

        // Best-effort cleanup; failures here are not test failures.
        let _ = fs::remove_file("test_tensor.json");
        let _ = fs::remove_file("test_tensor.bin");
    }

    /// An optimizer with accumulated state must survive a JSON save/load
    /// cycle, preserving its parameter count and learning rate.
    #[test]
    fn test_optimizer_serialization_roundtrip() {
        let mut param = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut adam = Adam::new();
        adam.add_parameter(&param);

        // One step so the optimizer carries non-trivial state.
        let mut loss = param.sum();
        loss.backward(None);
        adam.step(&mut [&mut param]);

        adam.save_json("test_optimizer.json").unwrap();
        let restored = Adam::load_json("test_optimizer.json").unwrap();

        assert_eq!(adam.parameter_count(), restored.parameter_count());
        assert_eq!(adam.learning_rate(), restored.learning_rate());

        // Best-effort cleanup; failures here are not test failures.
        let _ = fs::remove_file("test_optimizer.json");
    }
}