//! Example: advanced_callbacks/advanced_callbacks.rs

use ndarray::{Array2, ScalarOperand};
use num_traits::Float;
use rand::prelude::*;
use rand::rngs::SmallRng;
use scirs2_neural::callbacks::{CallbackManager, EarlyStopping};
use scirs2_neural::error::Result;
use scirs2_neural::layers::Dense;
use scirs2_neural::losses::MeanSquaredError;
use scirs2_neural::models::{sequential::Sequential, Model};
use scirs2_neural::optimizers::Adam;
use std::collections::HashMap;
use std::fmt::Debug;
use std::time::Instant;

15// Create a synthetic regression dataset (noisy sine wave)
16fn create_sine_dataset(
17    n_samples: usize,
18    noise_level: f32,
19    rng: &mut SmallRng,
20) -> (Array2<f32>, Array2<f32>) {
21    let mut x = Array2::<f32>::zeros((n_samples, 1));
22    let mut y = Array2::<f32>::zeros((n_samples, 1));
23
24    for i in 0..n_samples {
25        let x_val = (i as f32) / (n_samples as f32) * 4.0 * std::f32::consts::PI;
26        let y_val = x_val.sin();
27
28        // Add some noise
29        let noise = rng.random_range(-noise_level..noise_level);
30
31        x[[i, 0]] = x_val;
32        y[[i, 0]] = y_val + noise;
33    }
34
35    (x, y)
36}
37
38// Create a neural network model for regression
39fn create_regression_model(input_dim: usize, rng: &mut SmallRng) -> Result<Sequential<f32>> {
40    let mut model = Sequential::new();
41
42    // Input layer
43    let dense1 = Dense::new(input_dim, 16, Some("relu"), rng)?;
44    model.add_layer(dense1);
45
46    // Hidden layers
47    let dense2 = Dense::new(16, 8, Some("relu"), rng)?;
48    model.add_layer(dense2);
49
50    // Output layer (linear activation for regression)
51    let dense3 = Dense::new(8, 1, None, rng)?;
52    model.add_layer(dense3);
53
54    Ok(model)
55}
56
57// Calculate mean squared error
58fn calculate_mse<F: Float + Debug + ScalarOperand>(
59    model: &Sequential<F>,
60    x: &Array2<F>,
61    y: &Array2<F>,
62) -> Result<F> {
63    let predictions = model.forward(&x.clone().into_dyn())?;
64    let mut sum_squared_error = F::zero();
65
66    for i in 0..x.nrows() {
67        let diff = predictions[[i, 0]] - y[[i, 0]];
68        sum_squared_error = sum_squared_error + diff * diff;
69    }
70
71    Ok(sum_squared_error / F::from(x.nrows()).unwrap())
72}
73
74fn main() -> Result<()> {
75    println!("Advanced Learning Rate Scheduling and Early Stopping Example");
76    println!("==========================================================\n");
77
78    // Initialize random number generator
79    let mut rng = SmallRng::seed_from_u64(42);
80
81    // Create synthetic regression dataset
82    let n_samples = 100;
83    let (x, y) = create_sine_dataset(n_samples, 0.1, &mut rng);
84    println!(
85        "Created synthetic sine wave regression dataset with {} samples",
86        n_samples
87    );
88
89    // Generate 80% training data, 20% validation data
90    let train_size = (n_samples as f32 * 0.8) as usize;
91    let (x_train, y_train) = (
92        x.slice(ndarray::s![0..train_size, ..]).to_owned(),
93        y.slice(ndarray::s![0..train_size, ..]).to_owned(),
94    );
95
96    let (x_val, y_val) = (
97        x.slice(ndarray::s![train_size.., ..]).to_owned(),
98        y.slice(ndarray::s![train_size.., ..]).to_owned(),
99    );
100
101    println!(
102        "Split into {} training and {} validation samples",
103        x_train.nrows(),
104        x_val.nrows()
105    );
106
107    // Train with early stopping
108    println!("\nTraining with early stopping...");
109
110    let model = train_with_early_stopping(&mut rng, &x_train, &y_train, &x_val, &y_val)?;
111
112    // Evaluate final validation loss
113    let val_mse = calculate_mse(&model, &x_val, &y_val)?;
114    println!("\nFinal validation MSE: {:.6}", val_mse);
115
116    println!("\nAdvanced callbacks example completed successfully!");
117    Ok(())
118}
119
120// Train with early stopping and validation
121fn train_with_early_stopping(
122    rng: &mut SmallRng,
123    x_train: &Array2<f32>,
124    y_train: &Array2<f32>,
125    x_val: &Array2<f32>,
126    y_val: &Array2<f32>,
127) -> Result<Sequential<f32>> {
128    let mut model = create_regression_model(x_train.ncols(), rng)?;
129    println!("Created model with {} layers", model.num_layers());
130
131    // Setup loss function and optimizer
132    let loss_fn = MeanSquaredError::new();
133    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
134
135    // Setup early stopping callback
136    // Stop training if validation loss doesn't improve for 30 epochs
137    let early_stopping = EarlyStopping::new(30, 0.0001, true);
138    let mut callback_manager = CallbackManager::<f32>::new();
139    callback_manager.add_callback(Box::new(early_stopping));
140
141    println!("Starting training with early stopping (patience = 30 epochs)...");
142    let start_time = Instant::now();
143
144    // Convert to dynamic arrays
145    let x_train_dyn = x_train.clone().into_dyn();
146    let y_train_dyn = y_train.clone().into_dyn();
147
148    // Set up training parameters
149    let max_epochs = 500;
150
151    // Training loop with validation
152    let mut epoch_metrics = HashMap::new();
153    let mut best_val_loss = f32::MAX;
154    let mut stop_training = false;
155
156    for epoch in 0..max_epochs {
157        // Call callbacks before epoch
158        callback_manager.on_epoch_begin(epoch)?;
159
160        // Train one epoch
161        let train_loss = model.train_batch(&x_train_dyn, &y_train_dyn, &loss_fn, &mut optimizer)?;
162
163        // Validate
164        let val_loss = calculate_mse(&model, x_val, y_val)?;
165
166        // Update metrics
167        epoch_metrics.insert("loss".to_string(), train_loss);
168        epoch_metrics.insert("val_loss".to_string(), val_loss);
169
170        // Call callbacks after epoch
171        let should_stop = callback_manager.on_epoch_end(epoch, &epoch_metrics)?;
172
173        if should_stop {
174            println!("Early stopping triggered after {} epochs", epoch + 1);
175            stop_training = true;
176        }
177
178        // Track best validation loss
179        if val_loss < best_val_loss {
180            best_val_loss = val_loss;
181        }
182
183        // Print progress
184        if epoch % 50 == 0 || epoch == max_epochs - 1 || stop_training {
185            println!(
186                "Epoch {}/{}: train_loss = {:.6}, val_loss = {:.6}",
187                epoch + 1,
188                max_epochs,
189                train_loss,
190                val_loss
191            );
192        }
193
194        if stop_training {
195            break;
196        }
197    }
198
199    let elapsed = start_time.elapsed();
200    println!("Training completed in {:.2}s", elapsed.as_secs_f32());
201    println!("Best validation MSE: {:.6}", best_val_loss);
202
203    Ok(model)
204}