advanced_callbacks/
advanced_callbacks.rs1use ndarray::{Array2, ScalarOperand};
2use num_traits::Float;
3use rand::prelude::*;
4use rand::rngs::SmallRng;
5use scirs2_neural::callbacks::{CallbackManager, EarlyStopping};
6use scirs2_neural::error::Result;
7use scirs2_neural::layers::Dense;
8use scirs2_neural::losses::MeanSquaredError;
9use scirs2_neural::models::{sequential::Sequential, Model};
10use scirs2_neural::optimizers::Adam;
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::time::Instant;
14
15fn create_sine_dataset(
17 n_samples: usize,
18 noise_level: f32,
19 rng: &mut SmallRng,
20) -> (Array2<f32>, Array2<f32>) {
21 let mut x = Array2::<f32>::zeros((n_samples, 1));
22 let mut y = Array2::<f32>::zeros((n_samples, 1));
23
24 for i in 0..n_samples {
25 let x_val = (i as f32) / (n_samples as f32) * 4.0 * std::f32::consts::PI;
26 let y_val = x_val.sin();
27
28 let noise = rng.random_range(-noise_level..noise_level);
30
31 x[[i, 0]] = x_val;
32 y[[i, 0]] = y_val + noise;
33 }
34
35 (x, y)
36}
37
38fn create_regression_model(input_dim: usize, rng: &mut SmallRng) -> Result<Sequential<f32>> {
40 let mut model = Sequential::new();
41
42 let dense1 = Dense::new(input_dim, 16, Some("relu"), rng)?;
44 model.add_layer(dense1);
45
46 let dense2 = Dense::new(16, 8, Some("relu"), rng)?;
48 model.add_layer(dense2);
49
50 let dense3 = Dense::new(8, 1, None, rng)?;
52 model.add_layer(dense3);
53
54 Ok(model)
55}
56
57fn calculate_mse<F: Float + Debug + ScalarOperand>(
59 model: &Sequential<F>,
60 x: &Array2<F>,
61 y: &Array2<F>,
62) -> Result<F> {
63 let predictions = model.forward(&x.clone().into_dyn())?;
64 let mut sum_squared_error = F::zero();
65
66 for i in 0..x.nrows() {
67 let diff = predictions[[i, 0]] - y[[i, 0]];
68 sum_squared_error = sum_squared_error + diff * diff;
69 }
70
71 Ok(sum_squared_error / F::from(x.nrows()).unwrap())
72}
73
74fn main() -> Result<()> {
75 println!("Advanced Learning Rate Scheduling and Early Stopping Example");
76 println!("==========================================================\n");
77
78 let mut rng = SmallRng::seed_from_u64(42);
80
81 let n_samples = 100;
83 let (x, y) = create_sine_dataset(n_samples, 0.1, &mut rng);
84 println!(
85 "Created synthetic sine wave regression dataset with {} samples",
86 n_samples
87 );
88
89 let train_size = (n_samples as f32 * 0.8) as usize;
91 let (x_train, y_train) = (
92 x.slice(ndarray::s![0..train_size, ..]).to_owned(),
93 y.slice(ndarray::s![0..train_size, ..]).to_owned(),
94 );
95
96 let (x_val, y_val) = (
97 x.slice(ndarray::s![train_size.., ..]).to_owned(),
98 y.slice(ndarray::s![train_size.., ..]).to_owned(),
99 );
100
101 println!(
102 "Split into {} training and {} validation samples",
103 x_train.nrows(),
104 x_val.nrows()
105 );
106
107 println!("\nTraining with early stopping...");
109
110 let model = train_with_early_stopping(&mut rng, &x_train, &y_train, &x_val, &y_val)?;
111
112 let val_mse = calculate_mse(&model, &x_val, &y_val)?;
114 println!("\nFinal validation MSE: {:.6}", val_mse);
115
116 println!("\nAdvanced callbacks example completed successfully!");
117 Ok(())
118}
119
120fn train_with_early_stopping(
122 rng: &mut SmallRng,
123 x_train: &Array2<f32>,
124 y_train: &Array2<f32>,
125 x_val: &Array2<f32>,
126 y_val: &Array2<f32>,
127) -> Result<Sequential<f32>> {
128 let mut model = create_regression_model(x_train.ncols(), rng)?;
129 println!("Created model with {} layers", model.num_layers());
130
131 let loss_fn = MeanSquaredError::new();
133 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
134
135 let early_stopping = EarlyStopping::new(30, 0.0001, true);
138 let mut callback_manager = CallbackManager::<f32>::new();
139 callback_manager.add_callback(Box::new(early_stopping));
140
141 println!("Starting training with early stopping (patience = 30 epochs)...");
142 let start_time = Instant::now();
143
144 let x_train_dyn = x_train.clone().into_dyn();
146 let y_train_dyn = y_train.clone().into_dyn();
147
148 let max_epochs = 500;
150
151 let mut epoch_metrics = HashMap::new();
153 let mut best_val_loss = f32::MAX;
154 let mut stop_training = false;
155
156 for epoch in 0..max_epochs {
157 callback_manager.on_epoch_begin(epoch)?;
159
160 let train_loss = model.train_batch(&x_train_dyn, &y_train_dyn, &loss_fn, &mut optimizer)?;
162
163 let val_loss = calculate_mse(&model, x_val, y_val)?;
165
166 epoch_metrics.insert("loss".to_string(), train_loss);
168 epoch_metrics.insert("val_loss".to_string(), val_loss);
169
170 let should_stop = callback_manager.on_epoch_end(epoch, &epoch_metrics)?;
172
173 if should_stop {
174 println!("Early stopping triggered after {} epochs", epoch + 1);
175 stop_training = true;
176 }
177
178 if val_loss < best_val_loss {
180 best_val_loss = val_loss;
181 }
182
183 if epoch % 50 == 0 || epoch == max_epochs - 1 || stop_training {
185 println!(
186 "Epoch {}/{}: train_loss = {:.6}, val_loss = {:.6}",
187 epoch + 1,
188 max_epochs,
189 train_loss,
190 val_loss
191 );
192 }
193
194 if stop_training {
195 break;
196 }
197 }
198
199 let elapsed = start_time.elapsed();
200 println!("Training completed in {:.2}s", elapsed.as_secs_f32());
201 println!("Best validation MSE: {:.6}", best_val_loss);
202
203 Ok(model)
204}