scheduler_optimizer/scheduler_optimizer.rs
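//! Learning rate scheduler integration example.
//!
//! Trains a small MLP on the XOR problem three times to show how
//! learning-rate schedules can be combined with an optimizer:
//! 1. step decay via the `with_step_decay` helper,
//! 2. cosine annealing via the `with_cosine_annealing` helper,
//! 3. a manually driven `CosineAnnealingLR` scheduler whose value is
//!    pushed into the optimizer with `set_learning_rate` each epoch.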

use ndarray::Array2;
use rand::prelude::*;
use rand::rngs::SmallRng;
use scirs2_neural::callbacks::{CosineAnnealingLR, ScheduleMethod};
use scirs2_neural::error::Result;
use scirs2_neural::layers::Dense;
use scirs2_neural::losses::MeanSquaredError;
use scirs2_neural::models::{sequential::Sequential, Model};
use scirs2_neural::optimizers::{with_cosine_annealing, with_step_decay, Adam, Optimizer};
use std::time::Instant;

// Create XOR dataset
fn create_xor_dataset() -> (Array2<f32>, Array2<f32>) {
    // XOR truth table inputs
    let x = Array2::from_shape_vec(
        (4, 2),
        vec![
            0.0, 0.0, // 0 XOR 0 = 0
            0.0, 1.0, // 0 XOR 1 = 1
            1.0, 0.0, // 1 XOR 0 = 1
            1.0, 1.0, // 1 XOR 1 = 0
        ],
    )
    .unwrap();

    // XOR truth table outputs
    let y = Array2::from_shape_vec(
        (4, 1),
        vec![
            0.0, // 0 XOR 0 = 0
            1.0, // 0 XOR 1 = 1
            1.0, // 1 XOR 0 = 1
            0.0, // 1 XOR 1 = 0
        ],
    )
    .unwrap();

    (x, y)
}

// Create a simple neural network model for the XOR problem
fn create_xor_model(rng: &mut SmallRng) -> Result<Sequential<f32>> {
    let mut model = Sequential::new();

    // First hidden layer: takes the 2 XOR inputs and expands them to 8 units
    let dense1 = Dense::new(2, 8, Some("relu"), rng)?;
    model.add_layer(dense1);

    // Second hidden layer
    let dense2 = Dense::new(8, 4, Some("relu"), rng)?;
    model.add_layer(dense2);

    // Output layer with 1 neuron (XOR has 1 output)
    let dense3 = Dense::new(4, 1, Some("sigmoid"), rng)?;
    model.add_layer(dense3);

    Ok(model)
}

// Evaluate model by printing predictions for the XOR problem
fn evaluate_model(model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>) -> Result<f32> {
    let predictions = model.forward(&x.clone().into_dyn())?;
    let binary_thresh = 0.5;

    println!("\nModel predictions:");
    println!("-----------------");
    println!("   X₁   |   X₂   | Target | Prediction | Binary");
    println!("----------------------------------------------");

    let mut correct = 0;
    for i in 0..x.shape()[0] {
        let pred = predictions[[i, 0]];
        let binary_pred = pred > binary_thresh;
        let target = y[[i, 0]];
        let is_correct = (binary_pred as i32 as f32 - target).abs() < 1e-6;

        if is_correct {
            correct += 1;
        }

        println!(
            " {:.4}  | {:.4}  | {:.4}  |   {:.4}   |  {}  {}",
            x[[i, 0]],
            x[[i, 1]],
            target,
            pred,
            binary_pred as i32,
            if is_correct { "✓" } else { "✗" }
        );
    }

    let accuracy = correct as f32 / x.shape()[0] as f32;
    println!(
        "\nAccuracy: {:.2}% ({}/{})",
        accuracy * 100.0,
        correct,
        x.shape()[0]
    );

    Ok(accuracy)
}

fn main() -> Result<()> {
    println!("Learning Rate Scheduler Integration Example");
    println!("===========================================\n");

    // Initialize random number generator with a fixed seed for reproducibility
    let mut rng = SmallRng::seed_from_u64(42);

    // Create XOR dataset
    let (x, y) = create_xor_dataset();
    println!("Dataset created (XOR problem)");

    // Train with different scheduler-optimizer integrations
    train_with_step_decay(&mut rng, &x, &y)?;
    train_with_cosine_annealing(&mut rng, &x, &y)?;
    train_with_manual_scheduler_integration(&mut rng, &x, &y)?;

    println!("\nAll training examples completed successfully!");
    Ok(())
}

// Train with step decay learning rate scheduling
fn train_with_step_decay(rng: &mut SmallRng, x: &Array2<f32>, y: &Array2<f32>) -> Result<()> {
    println!("\n1. Training with Step Decay Learning Rate Scheduling");
    println!("--------------------------------------------------");

    let mut model = create_xor_model(rng)?;
    println!("Created model with {} layers", model.num_layers());

    // Setup loss function and optimizer with step decay scheduling
    let loss_fn = MeanSquaredError::new();

    // Option 1: Using the helper function
    let epochs = 300;
    let mut optimizer = with_step_decay(
        Adam::new(0.1, 0.9, 0.999, 1e-8),
        0.1,    // Initial LR
        0.5,    // Factor (reduce by half)
        50,     // Step size (every 50 epochs)
        0.001,  // Min LR
        epochs, // Total steps
    );

    println!("Starting training with step decay LR scheduling...");
    println!("Initial LR: 0.1, Factor: 0.5, Step size: 50 epochs");
    let start_time = Instant::now();

    // Convert to dynamic arrays
    let x_dyn = x.clone().into_dyn();
    let y_dyn = y.clone().into_dyn();

    // Training loop with learning rate tracking
    let mut lr_history = Vec::<(usize, f32)>::new();

    for epoch in 0..epochs {
        // Train one batch
        let loss = model.train_batch(&x_dyn, &y_dyn, &loss_fn, &mut optimizer)?;

        // Record current learning rate
        let current_lr = optimizer.get_learning_rate();

        // Track learning rate changes
        if lr_history.is_empty() || lr_history.last().unwrap().1 != current_lr {
            lr_history.push((epoch, current_lr));
        }

        // Print progress
        if epoch % 50 == 0 || epoch == epochs - 1 {
            println!(
                "Epoch {}/{}: loss = {:.6}, lr = {:.6}",
                epoch + 1,
                epochs,
                loss,
                current_lr
            );
        }
    }

    let elapsed = start_time.elapsed();
    println!("Training completed in {:.2}s", elapsed.as_secs_f32());

    // Print learning rate history
    println!("\nLearning rate changes:");
    for (epoch, lr) in lr_history {
        println!("Epoch {}: lr = {:.6}", epoch + 1, lr);
    }

    // Evaluate the model
    evaluate_model(&model, x, y)?;

    Ok(())
}

// Train with cosine annealing learning rate scheduling
fn train_with_cosine_annealing(rng: &mut SmallRng, x: &Array2<f32>, y: &Array2<f32>) -> Result<()> {
    println!("\n2. Training with Cosine Annealing Learning Rate Scheduling");
    println!("--------------------------------------------------------");

    let mut model = create_xor_model(rng)?;
    println!("Created model with {} layers", model.num_layers());

    // Setup loss function and optimizer with cosine annealing scheduling
    let loss_fn = MeanSquaredError::new();

    // Using the helper function for cosine annealing
    let epochs = 300;
    let cycle_length = 50;
    let mut optimizer = with_cosine_annealing(
        Adam::new(0.01, 0.9, 0.999, 1e-8),
        0.01,         // Max LR
        0.0001,       // Min LR
        cycle_length, // Cycle length
        epochs,       // Total steps
    );

    println!("Starting training with cosine annealing LR scheduling...");
    println!(
        "Max LR: 0.01, Min LR: 0.0001, Cycle length: {} epochs",
        cycle_length
    );
    let start_time = Instant::now();

    // Convert to dynamic arrays
    let x_dyn = x.clone().into_dyn();
    let y_dyn = y.clone().into_dyn();

    // Training loop with learning rate tracking
    let mut lr_samples = Vec::<(usize, f32)>::new();

    for epoch in 0..epochs {
        // Train one batch
        let loss = model.train_batch(&x_dyn, &y_dyn, &loss_fn, &mut optimizer)?;

        // Get current learning rate
        let current_lr = optimizer.get_learning_rate();

        // Record learning rate at specific points to show the cycle
        if epoch % 10 == 0 || epoch == epochs - 1 {
            lr_samples.push((epoch, current_lr));
        }

        // Print progress
        if epoch % 50 == 0 || epoch == epochs - 1 {
            println!(
                "Epoch {}/{}: loss = {:.6}, lr = {:.6}",
                epoch + 1,
                epochs,
                loss,
                current_lr
            );
        }
    }

    let elapsed = start_time.elapsed();
    println!("Training completed in {:.2}s", elapsed.as_secs_f32());

    // Print learning rate samples to demonstrate the cosine curve
    println!("\nLearning rate samples (showing cosine curve):");
    for (epoch, lr) in lr_samples {
        println!("Epoch {}: lr = {:.6}", epoch + 1, lr);
    }

    // Evaluate the model
    evaluate_model(&model, x, y)?;

    Ok(())
}

// Train with manual scheduler integration
fn train_with_manual_scheduler_integration(
    rng: &mut SmallRng,
    x: &Array2<f32>,
    y: &Array2<f32>,
) -> Result<()> {
    println!("\n3. Training with Manual Scheduler Integration");
    println!("-------------------------------------------");

    let mut model = create_xor_model(rng)?;
    println!("Created model with {} layers", model.num_layers());

    // Setup loss function and optimizer
    let loss_fn = MeanSquaredError::new();
    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);

    // Create scheduler manually
    let epochs = 300;
    let scheduler = CosineAnnealingLR::new(
        0.01,   // Max LR
        0.0001, // Min LR
        100,    // Cycle length
        ScheduleMethod::Epoch,
        epochs, // Total steps
    );

    println!("Starting training with manual scheduler integration...");
    println!("Max LR: 0.01, Min LR: 0.0001, Cycle length: 100 epochs");
    let start_time = Instant::now();

    // Convert to dynamic arrays
    let x_dyn = x.clone().into_dyn();
    let y_dyn = y.clone().into_dyn();

    // Training loop with manual scheduler updates
    for epoch in 0..epochs {
        // Update learning rate using scheduler
        let current_lr = scheduler.calculate_lr(epoch);
        optimizer.set_learning_rate(current_lr);

        // Train one batch
        let loss = model.train_batch(&x_dyn, &y_dyn, &loss_fn, &mut optimizer)?;

        // Print progress
        if epoch % 50 == 0 || epoch == epochs - 1 {
            println!(
                "Epoch {}/{}: loss = {:.6}, lr = {:.6}",
                epoch + 1,
                epochs,
                loss,
                current_lr
            );
        }
    }

    let elapsed = start_time.elapsed();
    println!("Training completed in {:.2}s", elapsed.as_secs_f32());

    // Evaluate the model
    evaluate_model(&model, x, y)?;

    Ok(())
}