improved_model_serialization/
improved_model_serialization.rs1use ndarray::Array2;
2use rand::prelude::*;
3use rand::rngs::SmallRng;
4use scirs2_neural::error::Result;
5use scirs2_neural::layers::{Dense, Dropout};
6use scirs2_neural::losses::CrossEntropyLoss;
7use scirs2_neural::models::{Model, Sequential};
8use scirs2_neural::optimizers::Adam;
9use scirs2_neural::serialization::{self, SerializationFormat};
10use std::path::Path;
11
12fn create_xor_model(rng: &mut SmallRng) -> Result<Sequential<f32>> {
14 let mut model = Sequential::new();
15
16 let input_dim = 2;
18 let hidden_dim = 4;
19 let output_dim = 1;
20
21 let dense1 = Dense::new(input_dim, hidden_dim, Some("relu"), rng)?;
23 model.add_layer(dense1);
24
25 let dropout = Dropout::new(0.1, rng)?;
27 model.add_layer(dropout);
28
29 let dense2 = Dense::new(hidden_dim, output_dim, Some("sigmoid"), rng)?;
31 model.add_layer(dense2);
32
33 Ok(model)
34}
35
36fn create_xor_dataset() -> (Array2<f32>, Array2<f32>) {
38 let x = Array2::from_shape_vec(
40 (4, 2),
41 vec![
42 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ],
47 )
48 .unwrap();
49
50 let y = Array2::from_shape_vec(
52 (4, 1),
53 vec![
54 0.0, 1.0, 1.0, 0.0, ],
59 )
60 .unwrap();
61
62 (x, y)
63}
64
65fn train_model(
67 model: &mut Sequential<f32>,
68 x: &Array2<f32>,
69 y: &Array2<f32>,
70 epochs: usize,
71) -> Result<()> {
72 println!("Training XOR model...");
73
74 let loss_fn = CrossEntropyLoss::new(1e-10);
76 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
77
78 for epoch in 0..epochs {
80 let x_dyn = x.clone().into_dyn();
82 let y_dyn = y.clone().into_dyn();
83
84 let loss = model.train_batch(&x_dyn, &y_dyn, &loss_fn, &mut optimizer)?;
86
87 if epoch % 100 == 0 || epoch == epochs - 1 {
89 println!("Epoch {}/{}: loss = {:.6}", epoch + 1, epochs, loss);
90 }
91 }
92
93 println!("Training completed.");
94 Ok(())
95}
96
97fn evaluate_model(model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>) -> Result<f32> {
99 let predictions = model.forward(&x.clone().into_dyn())?;
100 let binary_thresh = 0.5;
101
102 println!("\nModel predictions:");
103 println!("-----------------");
104 println!(" X₁ | X₂ | Target | Prediction | Binary");
105 println!("----------------------------------------------");
106
107 let mut correct = 0;
108 for i in 0..x.shape()[0] {
109 let pred = predictions[[i, 0]];
110 let binary_pred = pred > binary_thresh;
111 let target = y[[i, 0]];
112 let is_correct = (binary_pred as i32 as f32 - target).abs() < 1e-6;
113
114 if is_correct {
115 correct += 1;
116 }
117
118 println!(
119 " {:.4} | {:.4} | {:.4} | {:.4} | {} {}",
120 x[[i, 0]],
121 x[[i, 1]],
122 target,
123 pred,
124 binary_pred as i32,
125 if is_correct { "✓" } else { "✗" }
126 );
127 }
128
129 let accuracy = correct as f32 / x.shape()[0] as f32;
130 println!(
131 "\nAccuracy: {:.2}% ({}/{})",
132 accuracy * 100.0,
133 correct,
134 x.shape()[0]
135 );
136
137 Ok(accuracy)
138}
139
140fn create_noisy_xor_dataset(
142 size: usize,
143 noise_level: f32,
144 rng: &mut SmallRng,
145) -> (Array2<f32>, Array2<f32>) {
146 let mut x = Array2::<f32>::zeros((size, 2));
147 let mut y = Array2::<f32>::zeros((size, 1));
148
149 for i in 0..size {
150 let x1 = (rng.random_range(0.0..1.0) > 0.5) as i32 as f32;
152 let x2 = (rng.random_range(0.0..1.0) > 0.5) as i32 as f32;
153
154 x[[i, 0]] = x1 + rng.random_range(-noise_level / 2.0..noise_level / 2.0);
156 x[[i, 1]] = x2 + rng.random_range(-noise_level / 2.0..noise_level / 2.0);
157
158 y[[i, 0]] = (x1 as i32 ^ x2 as i32) as f32;
160 }
161
162 (x, y)
163}
164
165fn main() -> Result<()> {
166 println!("Improved Model Serialization and Loading Example");
167 println!("===============================================\n");
168
169 let mut rng = SmallRng::seed_from_u64(42);
171
172 let (x_train, y_train) = create_xor_dataset();
174 println!("XOR dataset created");
175
176 let mut model = create_xor_model(&mut rng)?;
178 println!("Model created with {} layers", model.num_layers());
179
180 train_model(&mut model, &x_train, &y_train, 2000)?;
182
183 println!("\nEvaluating model before saving:");
185 evaluate_model(&model, &x_train, &y_train)?;
186
187 println!("\nSaving model in different formats...");
189
190 let json_path = Path::new("xor_model.json");
192 serialization::save_model(&model, json_path, SerializationFormat::JSON)?;
193 println!("Model saved to {} in JSON format", json_path.display());
194
195 let cbor_path = Path::new("xor_model.cbor");
197 serialization::save_model(&model, cbor_path, SerializationFormat::CBOR)?;
198 println!("Model saved to {} in CBOR format", cbor_path.display());
199
200 let msgpack_path = Path::new("xor_model.msgpack");
202 serialization::save_model(&model, msgpack_path, SerializationFormat::MessagePack)?;
203 println!(
204 "Model saved to {} in MessagePack format",
205 msgpack_path.display()
206 );
207
208 println!("\nLoading and evaluating models from each format:");
210
211 println!("\n--- JSON Format ---");
213 let json_model = serialization::load_model::<f32, _>(json_path, SerializationFormat::JSON)?;
214 println!("JSON model loaded with {} layers", json_model.num_layers());
215 evaluate_model(&json_model, &x_train, &y_train)?;
216
217 println!("\n--- CBOR Format ---");
219 let cbor_model = serialization::load_model::<f32, _>(cbor_path, SerializationFormat::CBOR)?;
220 println!("CBOR model loaded with {} layers", cbor_model.num_layers());
221 evaluate_model(&cbor_model, &x_train, &y_train)?;
222
223 println!("\n--- MessagePack Format ---");
225 let msgpack_model =
226 serialization::load_model::<f32, _>(msgpack_path, SerializationFormat::MessagePack)?;
227 println!(
228 "MessagePack model loaded with {} layers",
229 msgpack_model.num_layers()
230 );
231 evaluate_model(&msgpack_model, &x_train, &y_train)?;
232
233 println!("\nTesting with larger, noisy dataset:");
235 let (x_test, y_test) = create_noisy_xor_dataset(100, 0.2, &mut rng);
236 evaluate_model(&model, &x_test, &y_test)?;
237
238 let json_size = std::fs::metadata(json_path)?.len();
240 let cbor_size = std::fs::metadata(cbor_path)?.len();
241 let msgpack_size = std::fs::metadata(msgpack_path)?.len();
242
243 println!("\nSerialization Format Comparison:");
244 println!(" JSON: {} bytes", json_size);
245 println!(
246 " CBOR: {} bytes ({:.1}% of JSON)",
247 cbor_size,
248 (cbor_size as f64 / json_size as f64) * 100.0
249 );
250 println!(
251 " MessagePack: {} bytes ({:.1}% of JSON)",
252 msgpack_size,
253 (msgpack_size as f64 / json_size as f64) * 100.0
254 );
255
256 println!("\nModel serialization and loading example completed successfully!");
257 Ok(())
258}