improved_model_serialization/
improved_model_serialization.rs

1use ndarray::Array2;
2use rand::prelude::*;
3use rand::rngs::SmallRng;
4use scirs2_neural::error::Result;
5use scirs2_neural::layers::{Dense, Dropout};
6use scirs2_neural::losses::CrossEntropyLoss;
7use scirs2_neural::models::{Model, Sequential};
8use scirs2_neural::optimizers::Adam;
9use scirs2_neural::serialization::{self, SerializationFormat};
10use std::path::Path;
11
12// Create a simple neural network model for the XOR problem
13fn create_xor_model(rng: &mut SmallRng) -> Result<Sequential<f32>> {
14    let mut model = Sequential::new();
15
16    // XOR problem requires a hidden layer
17    let input_dim = 2;
18    let hidden_dim = 4;
19    let output_dim = 1;
20
21    // Input to hidden layer with ReLU activation
22    let dense1 = Dense::new(input_dim, hidden_dim, Some("relu"), rng)?;
23    model.add_layer(dense1);
24
25    // Optional dropout for regularization (low rate as XOR is small)
26    let dropout = Dropout::new(0.1, rng)?;
27    model.add_layer(dropout);
28
29    // Hidden to output layer with sigmoid activation (binary output)
30    let dense2 = Dense::new(hidden_dim, output_dim, Some("sigmoid"), rng)?;
31    model.add_layer(dense2);
32
33    Ok(model)
34}
35
36// Create XOR dataset
37fn create_xor_dataset() -> (Array2<f32>, Array2<f32>) {
38    // XOR truth table inputs
39    let x = Array2::from_shape_vec(
40        (4, 2),
41        vec![
42            0.0, 0.0, // 0 XOR 0 = 0
43            0.0, 1.0, // 0 XOR 1 = 1
44            1.0, 0.0, // 1 XOR 0 = 1
45            1.0, 1.0, // 1 XOR 1 = 0
46        ],
47    )
48    .unwrap();
49
50    // XOR truth table outputs
51    let y = Array2::from_shape_vec(
52        (4, 1),
53        vec![
54            0.0, // 0 XOR 0 = 0
55            1.0, // 0 XOR 1 = 1
56            1.0, // 1 XOR 0 = 1
57            0.0, // 1 XOR 1 = 0
58        ],
59    )
60    .unwrap();
61
62    (x, y)
63}
64
65// Train the model on XOR problem
66fn train_model(
67    model: &mut Sequential<f32>,
68    x: &Array2<f32>,
69    y: &Array2<f32>,
70    epochs: usize,
71) -> Result<()> {
72    println!("Training XOR model...");
73
74    // Setup loss function and optimizer
75    let loss_fn = CrossEntropyLoss::new(1e-10);
76    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
77
78    // Train for specified number of epochs
79    for epoch in 0..epochs {
80        // Convert to dynamic dimension arrays
81        let x_dyn = x.clone().into_dyn();
82        let y_dyn = y.clone().into_dyn();
83
84        // Perform a training step
85        let loss = model.train_batch(&x_dyn, &y_dyn, &loss_fn, &mut optimizer)?;
86
87        // Print progress every 100 epochs
88        if epoch % 100 == 0 || epoch == epochs - 1 {
89            println!("Epoch {}/{}: loss = {:.6}", epoch + 1, epochs, loss);
90        }
91    }
92
93    println!("Training completed.");
94    Ok(())
95}
96
97// Evaluate model performance on XOR problem
98fn evaluate_model(model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>) -> Result<f32> {
99    let predictions = model.forward(&x.clone().into_dyn())?;
100    let binary_thresh = 0.5;
101
102    println!("\nModel predictions:");
103    println!("-----------------");
104    println!("   X₁   |   X₂   | Target | Prediction | Binary");
105    println!("----------------------------------------------");
106
107    let mut correct = 0;
108    for i in 0..x.shape()[0] {
109        let pred = predictions[[i, 0]];
110        let binary_pred = pred > binary_thresh;
111        let target = y[[i, 0]];
112        let is_correct = (binary_pred as i32 as f32 - target).abs() < 1e-6;
113
114        if is_correct {
115            correct += 1;
116        }
117
118        println!(
119            " {:.4}  | {:.4}  | {:.4}  |   {:.4}   |  {}  {}",
120            x[[i, 0]],
121            x[[i, 1]],
122            target,
123            pred,
124            binary_pred as i32,
125            if is_correct { "✓" } else { "✗" }
126        );
127    }
128
129    let accuracy = correct as f32 / x.shape()[0] as f32;
130    println!(
131        "\nAccuracy: {:.2}% ({}/{})",
132        accuracy * 100.0,
133        correct,
134        x.shape()[0]
135    );
136
137    Ok(accuracy)
138}
139
140// A more realistic dataset with noise to better test serialization
141fn create_noisy_xor_dataset(
142    size: usize,
143    noise_level: f32,
144    rng: &mut SmallRng,
145) -> (Array2<f32>, Array2<f32>) {
146    let mut x = Array2::<f32>::zeros((size, 2));
147    let mut y = Array2::<f32>::zeros((size, 1));
148
149    for i in 0..size {
150        // Generate binary inputs with some randomness
151        let x1 = (rng.random_range(0.0..1.0) > 0.5) as i32 as f32;
152        let x2 = (rng.random_range(0.0..1.0) > 0.5) as i32 as f32;
153
154        // Add some noise to inputs
155        x[[i, 0]] = x1 + rng.random_range(-noise_level / 2.0..noise_level / 2.0);
156        x[[i, 1]] = x2 + rng.random_range(-noise_level / 2.0..noise_level / 2.0);
157
158        // Standard XOR calculation for target
159        y[[i, 0]] = (x1 as i32 ^ x2 as i32) as f32;
160    }
161
162    (x, y)
163}
164
165fn main() -> Result<()> {
166    println!("Improved Model Serialization and Loading Example");
167    println!("===============================================\n");
168
169    // Initialize random number generator
170    let mut rng = SmallRng::seed_from_u64(42);
171
172    // 1. Create XOR datasets
173    let (x_train, y_train) = create_xor_dataset();
174    println!("XOR dataset created");
175
176    // 2. Create and train the model
177    let mut model = create_xor_model(&mut rng)?;
178    println!("Model created with {} layers", model.num_layers());
179
180    // Train the model
181    train_model(&mut model, &x_train, &y_train, 2000)?;
182
183    // 3. Evaluate the model before saving
184    println!("\nEvaluating model before saving:");
185    evaluate_model(&model, &x_train, &y_train)?;
186
187    // 4. Save the model in different formats
188    println!("\nSaving model in different formats...");
189
190    // Save in JSON format (human-readable)
191    let json_path = Path::new("xor_model.json");
192    serialization::save_model(&model, json_path, SerializationFormat::JSON)?;
193    println!("Model saved to {} in JSON format", json_path.display());
194
195    // Save in CBOR format (compact binary)
196    let cbor_path = Path::new("xor_model.cbor");
197    serialization::save_model(&model, cbor_path, SerializationFormat::CBOR)?;
198    println!("Model saved to {} in CBOR format", cbor_path.display());
199
200    // Save in MessagePack format (efficient binary)
201    let msgpack_path = Path::new("xor_model.msgpack");
202    serialization::save_model(&model, msgpack_path, SerializationFormat::MessagePack)?;
203    println!(
204        "Model saved to {} in MessagePack format",
205        msgpack_path.display()
206    );
207
208    // 5. Load models from each format and evaluate
209    println!("\nLoading and evaluating models from each format:");
210
211    // Load and evaluate JSON model
212    println!("\n--- JSON Format ---");
213    let json_model = serialization::load_model::<f32, _>(json_path, SerializationFormat::JSON)?;
214    println!("JSON model loaded with {} layers", json_model.num_layers());
215    evaluate_model(&json_model, &x_train, &y_train)?;
216
217    // Load and evaluate CBOR model
218    println!("\n--- CBOR Format ---");
219    let cbor_model = serialization::load_model::<f32, _>(cbor_path, SerializationFormat::CBOR)?;
220    println!("CBOR model loaded with {} layers", cbor_model.num_layers());
221    evaluate_model(&cbor_model, &x_train, &y_train)?;
222
223    // Load and evaluate MessagePack model
224    println!("\n--- MessagePack Format ---");
225    let msgpack_model =
226        serialization::load_model::<f32, _>(msgpack_path, SerializationFormat::MessagePack)?;
227    println!(
228        "MessagePack model loaded with {} layers",
229        msgpack_model.num_layers()
230    );
231    evaluate_model(&msgpack_model, &x_train, &y_train)?;
232
233    // 6. Test with a larger, noisy dataset to verify model works with unseen data
234    println!("\nTesting with larger, noisy dataset:");
235    let (x_test, y_test) = create_noisy_xor_dataset(100, 0.2, &mut rng);
236    evaluate_model(&model, &x_test, &y_test)?;
237
238    // File sizes for comparison
239    let json_size = std::fs::metadata(json_path)?.len();
240    let cbor_size = std::fs::metadata(cbor_path)?.len();
241    let msgpack_size = std::fs::metadata(msgpack_path)?.len();
242
243    println!("\nSerialization Format Comparison:");
244    println!("  JSON:       {} bytes", json_size);
245    println!(
246        "  CBOR:       {} bytes ({:.1}% of JSON)",
247        cbor_size,
248        (cbor_size as f64 / json_size as f64) * 100.0
249    );
250    println!(
251        "  MessagePack: {} bytes ({:.1}% of JSON)",
252        msgpack_size,
253        (msgpack_size as f64 / json_size as f64) * 100.0
254    );
255
256    println!("\nModel serialization and loading example completed successfully!");
257    Ok(())
258}