Function load_model

Source
pub fn load_model<F: Float + Debug + ScalarOperand + Send + Sync + 'static, P: AsRef<Path>>(
    path: P,
    format: SerializationFormat,
) -> Result<Sequential<F>>
Description

Load model from file

Examples found in repository:
examples/model_serialization_example.rs (line 72)
9fn main() -> Result<(), Box<dyn std::error::Error>> {
10    println!("Model Serialization Example");
11
12    // Initialize random number generator
13    let mut rng = SmallRng::seed_from_u64(42);
14
15    // 1. Create a simple neural network model
16    let mut model = Sequential::new();
17
18    // Add layers
19    let input_dim = 784; // MNIST image size: 28x28 = 784
20    let hidden_dim_1 = 256;
21    let hidden_dim_2 = 128;
22    let output_dim = 10; // 10 classes for digits 0-9
23
24    // Input layer to first hidden layer
25    let dense1 = Dense::new(input_dim, hidden_dim_1, Some("relu"), &mut rng)?;
26    model.add_layer(dense1);
27
28    // Dropout for regularization
29    let dropout1 = Dropout::new(0.2, &mut rng)?;
30    model.add_layer(dropout1);
31
32    // First hidden layer to second hidden layer
33    let dense2 = Dense::new(hidden_dim_1, hidden_dim_2, Some("relu"), &mut rng)?;
34    model.add_layer(dense2);
35
36    // Layer normalization
37    let layer_norm = LayerNorm::new(hidden_dim_2, 1e-5, &mut rng)?;
38    model.add_layer(layer_norm);
39
40    // Second hidden layer to output layer
41    let dense3 = Dense::new(hidden_dim_2, output_dim, Some("softmax"), &mut rng)?;
42    model.add_layer(dense3);
43
44    println!(
45        "Created a neural network with {} layers",
46        model.num_layers()
47    );
48
49    // 2. Test the model with some dummy input
50    let batch_size = 2;
51    let input = Array2::<f32>::from_elem((batch_size, input_dim), 0.1);
52    let output = model.forward(&input.clone().into_dyn())?;
53
54    println!("Model output shape: {:?}", output.shape());
55    println!("First few output values:");
56    for i in 0..batch_size {
57        print!("Sample {}: [ ", i);
58        for j in 0..5 {
59            // Print first 5 values
60            print!("{:.6} ", output[[i, j]]);
61        }
62        println!("... ]");
63    }
64
65    // 3. Save the model to a file
66    let model_path = Path::new("mnist_model.json");
67    serialization::save_model(&model, model_path, SerializationFormat::JSON)?;
68
69    println!("\nModel saved to {}", model_path.display());
70
71    // 4. Load the model from the file
72    let loaded_model = serialization::load_model::<f32, _>(model_path, SerializationFormat::JSON)?;
73
74    println!(
75        "Model loaded successfully with {} layers",
76        loaded_model.num_layers()
77    );
78
79    // 5. Test the loaded model with the same input
80    let loaded_output = loaded_model.forward(&input.into_dyn())?;
81
82    println!("\nLoaded model output shape: {:?}", loaded_output.shape());
83    println!("First few output values:");
84    for i in 0..batch_size {
85        print!("Sample {}: [ ", i);
86        for j in 0..5 {
87            // Print first 5 values
88            print!("{:.6} ", loaded_output[[i, j]]);
89        }
90        println!("... ]");
91    }
92
93    // 6. Compare original and loaded model outputs
94    let mut max_diff = 0.0;
95    for i in 0..batch_size {
96        for j in 0..output_dim {
97            let diff = (output[[i, j]] - loaded_output[[i, j]]).abs();
98            if diff > max_diff {
99                max_diff = diff;
100            }
101        }
102    }
103
104    println!(
105        "\nMaximum difference between original and loaded model outputs: {:.6}",
106        max_diff
107    );
108
109    if max_diff < 1e-6 {
110        println!("Models are identical! Serialization and deserialization worked correctly.");
111    } else {
112        println!("Warning: There are differences between the original and loaded models.");
113        println!(
114            "This might be due to numerical precision issues or a problem with serialization."
115        );
116    }
117
118    Ok(())
119}
Additional examples:
examples/improved_model_serialization.rs (line 213)
165fn main() -> Result<()> {
166    println!("Improved Model Serialization and Loading Example");
167    println!("===============================================\n");
168
169    // Initialize random number generator
170    let mut rng = SmallRng::seed_from_u64(42);
171
172    // 1. Create XOR datasets
173    let (x_train, y_train) = create_xor_dataset();
174    println!("XOR dataset created");
175
176    // 2. Create and train the model
177    let mut model = create_xor_model(&mut rng)?;
178    println!("Model created with {} layers", model.num_layers());
179
180    // Train the model
181    train_model(&mut model, &x_train, &y_train, 2000)?;
182
183    // 3. Evaluate the model before saving
184    println!("\nEvaluating model before saving:");
185    evaluate_model(&model, &x_train, &y_train)?;
186
187    // 4. Save the model in different formats
188    println!("\nSaving model in different formats...");
189
190    // Save in JSON format (human-readable)
191    let json_path = Path::new("xor_model.json");
192    serialization::save_model(&model, json_path, SerializationFormat::JSON)?;
193    println!("Model saved to {} in JSON format", json_path.display());
194
195    // Save in CBOR format (compact binary)
196    let cbor_path = Path::new("xor_model.cbor");
197    serialization::save_model(&model, cbor_path, SerializationFormat::CBOR)?;
198    println!("Model saved to {} in CBOR format", cbor_path.display());
199
200    // Save in MessagePack format (efficient binary)
201    let msgpack_path = Path::new("xor_model.msgpack");
202    serialization::save_model(&model, msgpack_path, SerializationFormat::MessagePack)?;
203    println!(
204        "Model saved to {} in MessagePack format",
205        msgpack_path.display()
206    );
207
208    // 5. Load models from each format and evaluate
209    println!("\nLoading and evaluating models from each format:");
210
211    // Load and evaluate JSON model
212    println!("\n--- JSON Format ---");
213    let json_model = serialization::load_model::<f32, _>(json_path, SerializationFormat::JSON)?;
214    println!("JSON model loaded with {} layers", json_model.num_layers());
215    evaluate_model(&json_model, &x_train, &y_train)?;
216
217    // Load and evaluate CBOR model
218    println!("\n--- CBOR Format ---");
219    let cbor_model = serialization::load_model::<f32, _>(cbor_path, SerializationFormat::CBOR)?;
220    println!("CBOR model loaded with {} layers", cbor_model.num_layers());
221    evaluate_model(&cbor_model, &x_train, &y_train)?;
222
223    // Load and evaluate MessagePack model
224    println!("\n--- MessagePack Format ---");
225    let msgpack_model =
226        serialization::load_model::<f32, _>(msgpack_path, SerializationFormat::MessagePack)?;
227    println!(
228        "MessagePack model loaded with {} layers",
229        msgpack_model.num_layers()
230    );
231    evaluate_model(&msgpack_model, &x_train, &y_train)?;
232
233    // 6. Test with a larger, noisy dataset to verify model works with unseen data
234    println!("\nTesting with larger, noisy dataset:");
235    let (x_test, y_test) = create_noisy_xor_dataset(100, 0.2, &mut rng);
236    evaluate_model(&model, &x_test, &y_test)?;
237
238    // File sizes for comparison
239    let json_size = std::fs::metadata(json_path)?.len();
240    let cbor_size = std::fs::metadata(cbor_path)?.len();
241    let msgpack_size = std::fs::metadata(msgpack_path)?.len();
242
243    println!("\nSerialization Format Comparison:");
244    println!("  JSON:       {} bytes", json_size);
245    println!(
246        "  CBOR:       {} bytes ({:.1}% of JSON)",
247        cbor_size,
248        (cbor_size as f64 / json_size as f64) * 100.0
249    );
250    println!(
251        "  MessagePack: {} bytes ({:.1}% of JSON)",
252        msgpack_size,
253        (msgpack_size as f64 / json_size as f64) * 100.0
254    );
255
256    println!("\nModel serialization and loading example completed successfully!");
257    Ok(())
258}