serialization_basics/
serialization_basics.rs

//! Serialization Basics Example
//!
//! This example demonstrates how to save and load tensors and optimizers in Train Station:
//! - Tensor serialization to JSON and binary formats
//! - Optimizer state persistence
//! - Format comparison and performance characteristics
//! - Model checkpointing workflows
//! - Error handling for serialization operations
//!
//! # Learning Objectives
//!
//! - Understand tensor and optimizer serialization
//! - Learn to save and load model states
//! - Compare different serialization formats
//! - Implement basic checkpointing workflows
//! - Handle serialization errors gracefully
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see tensor_basics.rs)
//! - Familiarity with file I/O operations
//!
//! # Usage
//!
//! ```bash
//! cargo run --example serialization_basics
//! ```
29
30use std::fs;
31use train_station::{
32    optimizers::{Adam, AdamConfig, Optimizer},
33    serialization::StructSerializable,
34    Tensor,
35};
36
37fn main() -> Result<(), Box<dyn std::error::Error>> {
38    println!("=== Serialization Basics Example ===\n");
39
40    demonstrate_tensor_serialization()?;
41    demonstrate_optimizer_serialization()?;
42    demonstrate_format_comparison()?;
43    demonstrate_model_checkpointing()?;
44    demonstrate_error_handling()?;
45    cleanup_temp_files()?;
46
47    println!("\n=== Example completed successfully! ===");
48    Ok(())
49}
50
51/// Demonstrate basic tensor serialization and deserialization
52fn demonstrate_tensor_serialization() -> Result<(), Box<dyn std::error::Error>> {
53    println!("--- Tensor Serialization ---");
54
55    // Create a tensor with some data
56    let original_tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
57    println!(
58        "Original tensor: shape {:?}, data: {:?}",
59        original_tensor.shape().dims,
60        original_tensor.data()
61    );
62
63    // Save tensor in JSON format
64    let json_path = "temp_tensor.json";
65    original_tensor.save_json(json_path)?;
66    println!("Saved tensor to JSON: {}", json_path);
67
68    // Load tensor from JSON
69    let loaded_tensor_json = Tensor::load_json(json_path)?;
70    println!(
71        "Loaded from JSON: shape {:?}, data: {:?}",
72        loaded_tensor_json.shape().dims,
73        loaded_tensor_json.data()
74    );
75
76    // Verify data integrity
77    assert_eq!(
78        original_tensor.shape().dims,
79        loaded_tensor_json.shape().dims
80    );
81    assert_eq!(original_tensor.data(), loaded_tensor_json.data());
82    println!("JSON serialization verification: PASSED");
83
84    // Save tensor in binary format
85    let binary_path = "temp_tensor.bin";
86    original_tensor.save_binary(binary_path)?;
87    println!("Saved tensor to binary: {}", binary_path);
88
89    // Load tensor from binary
90    let loaded_tensor_binary = Tensor::load_binary(binary_path)?;
91    println!(
92        "Loaded from binary: shape {:?}, data: {:?}",
93        loaded_tensor_binary.shape().dims,
94        loaded_tensor_binary.data()
95    );
96
97    // Verify data integrity
98    assert_eq!(
99        original_tensor.shape().dims,
100        loaded_tensor_binary.shape().dims
101    );
102    assert_eq!(original_tensor.data(), loaded_tensor_binary.data());
103    println!("Binary serialization verification: PASSED");
104
105    Ok(())
106}
107
108/// Demonstrate optimizer serialization and deserialization
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110    println!("\n--- Optimizer Serialization ---");
111
112    // Create an optimizer with some parameters
113    let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114    let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116    let config = AdamConfig {
117        learning_rate: 0.001,
118        beta1: 0.9,
119        beta2: 0.999,
120        eps: 1e-8,
121        weight_decay: 0.0,
122        amsgrad: false,
123    };
124
125    let mut optimizer = Adam::with_config(config);
126    optimizer.add_parameter(&weight);
127    optimizer.add_parameter(&bias);
128
129    println!(
130        "Created optimizer with {} parameters",
131        optimizer.parameter_count()
132    );
133    println!("Learning rate: {}", optimizer.learning_rate());
134
135    // Simulate some training steps
136    for _ in 0..3 {
137        let mut loss = weight.sum() + bias.sum();
138        loss.backward(None);
139        optimizer.step(&mut [&mut weight, &mut bias]);
140        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141    }
142
143    // Save optimizer state
144    let optimizer_path = "temp_optimizer.json";
145    optimizer.save_json(optimizer_path)?;
146    println!("Saved optimizer to: {}", optimizer_path);
147
148    // Load optimizer state
149    let loaded_optimizer = Adam::load_json(optimizer_path)?;
150    println!(
151        "Loaded optimizer with {} parameters",
152        loaded_optimizer.parameter_count()
153    );
154    println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156    // Verify optimizer state
157    assert_eq!(
158        optimizer.parameter_count(),
159        loaded_optimizer.parameter_count()
160    );
161    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162    println!("Optimizer serialization verification: PASSED");
163
164    Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169    println!("\n--- Format Comparison ---");
170
171    // Create a larger tensor for comparison
172    let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174    // Save in both formats
175    tensor.save_json("temp_comparison.json")?;
176    tensor.save_binary("temp_comparison.bin")?;
177
178    // Compare file sizes
179    let json_size = fs::metadata("temp_comparison.json")?.len();
180    let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182    println!("JSON file size: {} bytes", json_size);
183    println!("Binary file size: {} bytes", binary_size);
184    println!(
185        "Compression ratio: {:.2}x",
186        json_size as f64 / binary_size as f64
187    );
188
189    // Load and verify both formats
190    let json_tensor = Tensor::load_json("temp_comparison.json")?;
191    let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193    assert_eq!(tensor.shape().dims, json_tensor.shape().dims);
194    assert_eq!(tensor.shape().dims, binary_tensor.shape().dims);
195    assert_eq!(tensor.data(), json_tensor.data());
196    assert_eq!(tensor.data(), binary_tensor.data());
197
198    println!("Format comparison verification: PASSED");
199
200    Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205    println!("\n--- Model Checkpointing ---");
206
207    // Create a simple model (weights and bias)
208    let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209    let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211    // Create optimizer
212    let mut optimizer = Adam::with_learning_rate(0.01);
213    optimizer.add_parameter(&weights);
214    optimizer.add_parameter(&bias);
215
216    println!("Initial weights: {:?}", weights.data());
217    println!("Initial bias: {:?}", bias.data());
218
219    // Simulate training
220    for epoch in 0..5 {
221        let mut loss = weights.sum() + bias.sum();
222        loss.backward(None);
223        optimizer.step(&mut [&mut weights, &mut bias]);
224        optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226        if epoch % 2 == 0 {
227            // Save checkpoint
228            let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229            fs::create_dir_all(&checkpoint_dir)?;
230
231            weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232            bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233            optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235            println!("Saved checkpoint for epoch {}", epoch);
236        }
237    }
238
239    // Load from checkpoint
240    let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241    let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242    let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244    println!("Loaded weights: {:?}", loaded_weights.data());
245    println!("Loaded bias: {:?}", loaded_bias.data());
246    println!(
247        "Loaded optimizer learning rate: {}",
248        loaded_optimizer.learning_rate()
249    );
250
251    // Verify checkpoint integrity
252    assert_eq!(weights.shape().dims, loaded_weights.shape().dims);
253    assert_eq!(bias.shape().dims, loaded_bias.shape().dims);
254    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256    println!("Checkpointing verification: PASSED");
257
258    Ok(())
259}
260
261/// Demonstrate error handling for serialization operations
262fn demonstrate_error_handling() -> Result<(), Box<dyn std::error::Error>> {
263    println!("\n--- Error Handling ---");
264
265    // Test loading non-existent file
266    match Tensor::load_json("nonexistent_file.json") {
267        Ok(_) => println!("Unexpected: Successfully loaded non-existent file"),
268        Err(e) => println!("Expected error loading non-existent file: {}", e),
269    }
270
271    // Test loading with wrong format
272    let tensor = Tensor::randn(vec![2, 2], Some(47));
273    tensor.save_binary("temp_binary.bin")?;
274
275    match Tensor::load_json("temp_binary.bin") {
276        Ok(_) => println!("Unexpected: Successfully loaded binary as JSON"),
277        Err(e) => println!("Expected error loading binary as JSON: {}", e),
278    }
279
280    // Test loading corrupted file
281    fs::write("temp_invalid.json", "invalid json content")?;
282    match Tensor::load_json("temp_invalid.json") {
283        Ok(_) => println!("Unexpected: Successfully loaded invalid JSON"),
284        Err(e) => println!("Expected error loading invalid JSON: {}", e),
285    }
286
287    println!("Error handling verification: PASSED");
288
289    Ok(())
290}
291
/// Removes the temporary files and checkpoint directories created by the example.
///
/// Removal is attempted directly instead of checking for existence first, so there
/// is no gap between the check and the removal. Entries that do not exist are
/// skipped silently. A failure on one entry does not stop the remaining entries
/// from being removed.
///
/// # Errors
///
/// Returns the first I/O error (other than "not found") hit while removing an
/// entry. In that case `Cleanup completed` is not printed.
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_tensor.json",
        "temp_tensor.bin",
        "temp_optimizer.json",
        "temp_comparison.json",
        "temp_comparison.bin",
        "temp_binary.bin",
        "temp_invalid.json",
    ];

    // Remember only the first real failure; keep cleaning up regardless.
    let mut first_error: Option<std::io::Error> = None;

    for file in &files_to_remove {
        match fs::remove_file(file) {
            Ok(()) => println!("Removed: {}", file),
            // Nothing to clean up for this entry.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
    }

    // Remove checkpoint directories
    for epoch in [0, 2, 4] {
        let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
        match fs::remove_dir_all(&checkpoint_dir) {
            Ok(()) => println!("Removed directory: {}", checkpoint_dir),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
    }

    if let Some(e) = first_error {
        return Err(e.into());
    }

    println!("Cleanup completed");
    Ok(())
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 tensor must survive a save/load cycle in both JSON and binary form.
    #[test]
    fn test_tensor_serialization_roundtrip() {
        let json_file = "test_tensor.json";
        let binary_file = "test_tensor.bin";
        let reference = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // JSON leg
        reference.save_json(json_file).unwrap();
        let from_json = Tensor::load_json(json_file).unwrap();
        assert_eq!(reference.shape().dims, from_json.shape().dims);
        assert_eq!(reference.data(), from_json.data());

        // Binary leg
        reference.save_binary(binary_file).unwrap();
        let from_binary = Tensor::load_binary(binary_file).unwrap();
        assert_eq!(reference.shape().dims, from_binary.shape().dims);
        assert_eq!(reference.data(), from_binary.data());

        // Best-effort removal; a failure to delete is deliberately ignored.
        let _ = fs::remove_file(json_file);
        let _ = fs::remove_file(binary_file);
    }

    /// An Adam optimizer that has taken one step must reload with the same
    /// parameter count and learning rate.
    #[test]
    fn test_optimizer_serialization_roundtrip() {
        let optimizer_file = "test_optimizer.json";

        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // One update before the optimizer is saved.
        let mut loss = weight.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight]);

        optimizer.save_json(optimizer_file).unwrap();
        let restored = Adam::load_json(optimizer_file).unwrap();

        assert_eq!(optimizer.parameter_count(), restored.parameter_count());
        assert_eq!(optimizer.learning_rate(), restored.learning_rate());

        // Best-effort removal; a failure to delete is deliberately ignored.
        let _ = fs::remove_file(optimizer_file);
    }
}