pub fn load_model<F: Float + Debug + ScalarOperand + Send + Sync + 'static, P: AsRef<Path>>(
path: P,
format: SerializationFormat,
) -> Result<Sequential<F>>
Expand description
Loads a `Sequential<F>` model from the file at `path`, decoding it according to the given `SerializationFormat`.
Examples found in repository:
examples/model_serialization_example.rs (line 72)
9fn main() -> Result<(), Box<dyn std::error::Error>> {
10 println!("Model Serialization Example");
11
12 // Initialize random number generator
13 let mut rng = SmallRng::seed_from_u64(42);
14
15 // 1. Create a simple neural network model
16 let mut model = Sequential::new();
17
18 // Add layers
19 let input_dim = 784; // MNIST image size: 28x28 = 784
20 let hidden_dim_1 = 256;
21 let hidden_dim_2 = 128;
22 let output_dim = 10; // 10 classes for digits 0-9
23
24 // Input layer to first hidden layer
25 let dense1 = Dense::new(input_dim, hidden_dim_1, Some("relu"), &mut rng)?;
26 model.add_layer(dense1);
27
28 // Dropout for regularization
29 let dropout1 = Dropout::new(0.2, &mut rng)?;
30 model.add_layer(dropout1);
31
32 // First hidden layer to second hidden layer
33 let dense2 = Dense::new(hidden_dim_1, hidden_dim_2, Some("relu"), &mut rng)?;
34 model.add_layer(dense2);
35
36 // Layer normalization
37 let layer_norm = LayerNorm::new(hidden_dim_2, 1e-5, &mut rng)?;
38 model.add_layer(layer_norm);
39
40 // Second hidden layer to output layer
41 let dense3 = Dense::new(hidden_dim_2, output_dim, Some("softmax"), &mut rng)?;
42 model.add_layer(dense3);
43
44 println!(
45 "Created a neural network with {} layers",
46 model.num_layers()
47 );
48
49 // 2. Test the model with some dummy input
50 let batch_size = 2;
51 let input = Array2::<f32>::from_elem((batch_size, input_dim), 0.1);
52 let output = model.forward(&input.clone().into_dyn())?;
53
54 println!("Model output shape: {:?}", output.shape());
55 println!("First few output values:");
56 for i in 0..batch_size {
57 print!("Sample {}: [ ", i);
58 for j in 0..5 {
59 // Print first 5 values
60 print!("{:.6} ", output[[i, j]]);
61 }
62 println!("... ]");
63 }
64
65 // 3. Save the model to a file
66 let model_path = Path::new("mnist_model.json");
67 serialization::save_model(&model, model_path, SerializationFormat::JSON)?;
68
69 println!("\nModel saved to {}", model_path.display());
70
71 // 4. Load the model from the file
72 let loaded_model = serialization::load_model::<f32, _>(model_path, SerializationFormat::JSON)?;
73
74 println!(
75 "Model loaded successfully with {} layers",
76 loaded_model.num_layers()
77 );
78
79 // 5. Test the loaded model with the same input
80 let loaded_output = loaded_model.forward(&input.into_dyn())?;
81
82 println!("\nLoaded model output shape: {:?}", loaded_output.shape());
83 println!("First few output values:");
84 for i in 0..batch_size {
85 print!("Sample {}: [ ", i);
86 for j in 0..5 {
87 // Print first 5 values
88 print!("{:.6} ", loaded_output[[i, j]]);
89 }
90 println!("... ]");
91 }
92
93 // 6. Compare original and loaded model outputs
94 let mut max_diff = 0.0;
95 for i in 0..batch_size {
96 for j in 0..output_dim {
97 let diff = (output[[i, j]] - loaded_output[[i, j]]).abs();
98 if diff > max_diff {
99 max_diff = diff;
100 }
101 }
102 }
103
104 println!(
105 "\nMaximum difference between original and loaded model outputs: {:.6}",
106 max_diff
107 );
108
109 if max_diff < 1e-6 {
110 println!("Models are identical! Serialization and deserialization worked correctly.");
111 } else {
112 println!("Warning: There are differences between the original and loaded models.");
113 println!(
114 "This might be due to numerical precision issues or a problem with serialization."
115 );
116 }
117
118 Ok(())
119}
More examples
examples/improved_model_serialization.rs (line 213)
165fn main() -> Result<()> {
166 println!("Improved Model Serialization and Loading Example");
167 println!("===============================================\n");
168
169 // Initialize random number generator
170 let mut rng = SmallRng::seed_from_u64(42);
171
172 // 1. Create XOR datasets
173 let (x_train, y_train) = create_xor_dataset();
174 println!("XOR dataset created");
175
176 // 2. Create and train the model
177 let mut model = create_xor_model(&mut rng)?;
178 println!("Model created with {} layers", model.num_layers());
179
180 // Train the model
181 train_model(&mut model, &x_train, &y_train, 2000)?;
182
183 // 3. Evaluate the model before saving
184 println!("\nEvaluating model before saving:");
185 evaluate_model(&model, &x_train, &y_train)?;
186
187 // 4. Save the model in different formats
188 println!("\nSaving model in different formats...");
189
190 // Save in JSON format (human-readable)
191 let json_path = Path::new("xor_model.json");
192 serialization::save_model(&model, json_path, SerializationFormat::JSON)?;
193 println!("Model saved to {} in JSON format", json_path.display());
194
195 // Save in CBOR format (compact binary)
196 let cbor_path = Path::new("xor_model.cbor");
197 serialization::save_model(&model, cbor_path, SerializationFormat::CBOR)?;
198 println!("Model saved to {} in CBOR format", cbor_path.display());
199
200 // Save in MessagePack format (efficient binary)
201 let msgpack_path = Path::new("xor_model.msgpack");
202 serialization::save_model(&model, msgpack_path, SerializationFormat::MessagePack)?;
203 println!(
204 "Model saved to {} in MessagePack format",
205 msgpack_path.display()
206 );
207
208 // 5. Load models from each format and evaluate
209 println!("\nLoading and evaluating models from each format:");
210
211 // Load and evaluate JSON model
212 println!("\n--- JSON Format ---");
213 let json_model = serialization::load_model::<f32, _>(json_path, SerializationFormat::JSON)?;
214 println!("JSON model loaded with {} layers", json_model.num_layers());
215 evaluate_model(&json_model, &x_train, &y_train)?;
216
217 // Load and evaluate CBOR model
218 println!("\n--- CBOR Format ---");
219 let cbor_model = serialization::load_model::<f32, _>(cbor_path, SerializationFormat::CBOR)?;
220 println!("CBOR model loaded with {} layers", cbor_model.num_layers());
221 evaluate_model(&cbor_model, &x_train, &y_train)?;
222
223 // Load and evaluate MessagePack model
224 println!("\n--- MessagePack Format ---");
225 let msgpack_model =
226 serialization::load_model::<f32, _>(msgpack_path, SerializationFormat::MessagePack)?;
227 println!(
228 "MessagePack model loaded with {} layers",
229 msgpack_model.num_layers()
230 );
231 evaluate_model(&msgpack_model, &x_train, &y_train)?;
232
233 // 6. Test with a larger, noisy dataset to verify model works with unseen data
234 println!("\nTesting with larger, noisy dataset:");
235 let (x_test, y_test) = create_noisy_xor_dataset(100, 0.2, &mut rng);
236 evaluate_model(&model, &x_test, &y_test)?;
237
238 // File sizes for comparison
239 let json_size = std::fs::metadata(json_path)?.len();
240 let cbor_size = std::fs::metadata(cbor_path)?.len();
241 let msgpack_size = std::fs::metadata(msgpack_path)?.len();
242
243 println!("\nSerialization Format Comparison:");
244 println!(" JSON: {} bytes", json_size);
245 println!(
246 " CBOR: {} bytes ({:.1}% of JSON)",
247 cbor_size,
248 (cbor_size as f64 / json_size as f64) * 100.0
249 );
250 println!(
251 " MessagePack: {} bytes ({:.1}% of JSON)",
252 msgpack_size,
253 (msgpack_size as f64 / json_size as f64) * 100.0
254 );
255
256 println!("\nModel serialization and loading example completed successfully!");
257 Ok(())
258}