Struct StepDecay

pub struct StepDecay<F: Float + Debug + ScalarOperand> { /* private fields */ }

Step decay learning rate scheduler

Multiplies the learning rate by a fixed factor every step_size steps (epochs or batches), never letting it fall below a configured minimum.

Implementations

impl<F: Float + Debug + ScalarOperand> StepDecay<F>

pub fn new(
    initial_lr: F,
    factor: F,
    step_size: usize,
    method: ScheduleMethod,
    min_lr: F,
) -> Self

Create a new step decay scheduler

Arguments
  • initial_lr - Initial learning rate
  • factor - Factor to multiply the learning rate by (should be less than 1.0)
  • step_size - Number of steps between reductions
  • method - Whether the schedule advances per epoch or per batch
  • min_lr - Minimum learning rate; the decayed rate never falls below this value
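
These parameters imply a simple closed-form schedule. As a minimal sketch (plain Rust, independent of this crate; that the decayed rate is clamped from below at min_lr is an assumption based on the argument's description), the learning rate after a given number of steps is:

fn expected_lr(initial_lr: f64, factor: f64, step_size: usize, min_lr: f64, step: usize) -> f64 {
    // Number of completed decay intervals: floor(step / step_size).
    let k = (step / step_size) as i32;
    // initial_lr * factor^k, clamped from below at min_lr (assumption).
    (initial_lr * factor.powi(k)).max(min_lr)
}

With the values used in the example below (0.01, 0.5, 30, 1e-4), this gives 0.01 for steps 0-29, 0.005 for steps 30-59, 0.0025 for steps 60-89, and so on, never dropping below 1e-4.
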
Examples found in repository:
examples/visualize_training_progress.rs (lines 115-121)
83 fn main() -> Result<()> {
84    println!("Training Visualization Example");
85    println!("==============================\n");
86
87    // Initialize RNG with a fixed seed for reproducibility
88    let mut rng = SmallRng::seed_from_u64(42);
89
90    // Generate synthetic data
91    let num_samples = 200;
92    let (x, y) = generate_nonlinear_data(num_samples, &mut rng);
93    println!("Generated synthetic dataset with {} samples", num_samples);
94
95    // Split data into training and validation sets
96    let (x_train, y_train, x_val, y_val) = train_val_split(&x, &y, 0.8);
97    println!(
98        "Split data: {} training samples, {} validation samples",
99        x_train.shape()[0],
100        x_val.shape()[0]
101    );
102
103    // Create model
104    let mut model = create_regression_model(1, &mut rng)?;
105    println!("Created model with {} layers", model.num_layers());
106
107    // Setup loss function and optimizer
108    let loss_fn = MeanSquaredError::new();
109    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
110
111    // Training parameters
112    let epochs = 100;
113
114    // Configure learning rate scheduler
115    let mut scheduler = StepDecay::new(
116        0.01, // Initial learning rate
117        0.5,  // Decay factor
118        30,   // Step size
119        ScheduleMethod::Epoch,
120        1e-4, // Min learning rate
121    );
122
123    // Add visualization callback
124    let mut visualization_cb =
125        VisualizationCallback::new(5) // Show every 5 epochs
126            .with_tracked_metrics(vec!["train_loss".to_string(), "val_loss".to_string()])
127            .with_plot_options(PlotOptions {
128                width: 80,
129                height: 15,
130                max_x_ticks: 10,
131                max_y_ticks: 5,
132                line_char: '─',
133                point_char: '●',
134                background_char: ' ',
135                show_grid: true,
136                show_legend: true,
137            })
138            .with_save_path("training_plot.txt");
139
140    // Train the model manually
141    let mut epoch_history = HashMap::new();
142    epoch_history.insert("train_loss".to_string(), Vec::new());
143    epoch_history.insert("val_loss".to_string(), Vec::new());
144
145    // Convert data to dynamic arrays and ensure they are owned (not views)
146    let x_train_dyn = x_train.clone().into_dyn();
147    let y_train_dyn = y_train.clone().into_dyn();
148    let x_val_dyn = x_val.clone().into_dyn();
149    let y_val_dyn = y_val.clone().into_dyn();
150
151    println!("\nStarting training with visualization...");
152
153    // Manual training loop
154    println!("\nStarting training loop...");
155    for epoch in 0..epochs {
156        // Update learning rate with scheduler
157        let current_lr = scheduler.get_lr();
158        optimizer.set_learning_rate(current_lr);
159
160        // Train for one epoch (batch size = full dataset in this example)
161        let train_loss = model.train_batch(&x_train_dyn, &y_train_dyn, &loss_fn, &mut optimizer)?;
162
163        // Compute validation loss
164        let predictions = model.forward(&x_val_dyn)?;
165        let val_loss = loss_fn.forward(&predictions, &y_val_dyn)?;
166
167        // Store metrics
168        epoch_history
169            .get_mut("train_loss")
170            .unwrap()
171            .push(train_loss);
172        epoch_history.get_mut("val_loss").unwrap().push(val_loss);
173
174        // Update the scheduler
175        scheduler.update_lr(epoch);
176
177        // Print progress
178        if epoch % 10 == 0 || epoch == epochs - 1 {
179            println!(
180                "Epoch {}/{}: train_loss = {:.6}, val_loss = {:.6}, lr = {:.6}",
181                epoch + 1,
182                epochs,
183                train_loss,
184                val_loss,
185                current_lr
186            );
187        }
188
189        // Visualize progress (manually calling the visualization callback)
190        if epoch % 5 == 0 || epoch == epochs - 1 {
191            let mut context = CallbackContext {
192                epoch,
193                total_epochs: epochs,
194                batch: 0,
195                total_batches: 1,
196                batch_loss: None,
197                epoch_loss: Some(train_loss),
198                val_loss: Some(val_loss),
199                metrics: vec![],
200                history: &epoch_history,
201                stop_training: false,
202                model: None,
203            };
204
205            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
206        }
207    }
208
209    // Final visualization
210    let mut context = CallbackContext {
211        epoch: epochs - 1,
212        total_epochs: epochs,
213        batch: 0,
214        total_batches: 1,
215        batch_loss: None,
216        epoch_loss: Some(*epoch_history.get("train_loss").unwrap().last().unwrap()),
217        val_loss: Some(*epoch_history.get("val_loss").unwrap().last().unwrap()),
218        metrics: vec![],
219        history: &epoch_history,
220        stop_training: false,
221        model: None,
222    };
223
224    visualization_cb.on_event(CallbackTiming::AfterTraining, &mut context)?;
225
226    println!("\nTraining complete!");
227
228    // Export history to CSV
229    export_history_to_csv(&epoch_history, "training_history.csv")?;
230    println!("Exported training history to training_history.csv");
231
232    // Analyze training results
233    let analysis = analyze_training_history(&epoch_history);
234    println!("\nTraining Analysis:");
235    println!("------------------");
236    for issue in analysis {
237        println!("{}", issue);
238    }
239
240    // Make predictions on validation data
241    println!("\nMaking predictions on validation data...");
242    let predictions = model.forward(&x_val_dyn)?;
243
244    // Calculate and display final metrics
245    let mse = loss_fn.forward(&predictions, &y_val_dyn)?;
246    println!("Final validation MSE: {:.6}", mse);
247
248    // Display a few sample predictions
249    println!("\nSample predictions:");
250    println!("------------------");
251    println!("  X  |  True Y  | Predicted Y ");
252    println!("---------------------------");
253
254    // Show first 5 predictions
255    let num_samples_to_show = std::cmp::min(5, x_val.shape()[0]);
256    for i in 0..num_samples_to_show {
257        println!(
258            "{:.4} | {:.4}   | {:.4}",
259            x_val[[i, 0]],
260            y_val[[i, 0]],
261            predictions[[i, 0]]
262        );
263    }
264
265    println!("\nVisualization demonstration complete!");
266    Ok(())
267 }

pub fn get_initial_lr(&self) -> F

Get the initial learning rate

pub fn get_lr(&self) -> F

Get the current learning rate

Examples found in repository:
examples/visualize_training_progress.rs (line 157). See the full listing under new above.

pub fn update_lr(&mut self, step: usize)

Update the learning rate based on the current step

Examples found in repository:
examples/visualize_training_progress.rs (line 175). See the full listing under new above.

pub fn reset_to_initial(&mut self)

Reset the learning rate to the initial value
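
This makes a single scheduler reusable across multiple training runs, e.g. hyperparameter trials. A minimal sketch (train_one_run is a hypothetical helper, not part of this crate):

for run in 0..3 {
    // Start each run from initial_lr rather than the decayed value
    // left over from the previous run.
    scheduler.reset_to_initial();
    train_one_run(&mut model, &mut scheduler, &mut optimizer)?; // hypothetical helper
}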

Trait Implementations

impl<F: Float + Debug + ScalarOperand> Callback<F> for StepDecay<F>

fn on_event(
    &mut self,
    timing: CallbackTiming,
    context: &mut CallbackContext<'_, F>,
) -> Result<()>

Called during training at specific points
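
Because StepDecay implements Callback, a training loop can drive it through callback dispatch instead of calling update_lr by hand. A minimal sketch, reusing the CallbackContext layout and the epoch, epochs, train_loss, and epoch_history variables from the example above:

let mut context = CallbackContext {
    epoch,
    total_epochs: epochs,
    batch: 0,
    total_batches: 1,
    batch_loss: None,
    epoch_loss: Some(train_loss),
    val_loss: None,
    metrics: vec![],
    history: &epoch_history,
    stop_training: false,
    model: None,
};
// Assumption: the scheduler advances its schedule when it sees this event.
scheduler.on_event(CallbackTiming::AfterEpoch, &mut context)?;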

impl<F: Float + Debug + ScalarOperand> LearningRateScheduler<F> for StepDecay<F>

fn get_learning_rate(&mut self, progress: f64) -> Result<F>

Get the learning rate for the current progress level (0.0 to 1.0)
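
A minimal sketch of driving the schedule through this trait rather than update_lr, assuming progress is the fraction of training completed so far (epoch / total epochs):

for epoch in 0..epochs {
    // Map the epoch counter onto the 0.0..1.0 progress range (assumption).
    let progress = epoch as f64 / epochs as f64;
    let lr = scheduler.get_learning_rate(progress)?;
    optimizer.set_learning_rate(lr);
    // ... train one epoch ...
}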

fn reset(&mut self)

Reset the scheduler state

Auto Trait Implementations

impl<F> Freeze for StepDecay<F>
where F: Freeze,

impl<F> RefUnwindSafe for StepDecay<F>
where F: RefUnwindSafe,

impl<F> Send for StepDecay<F>
where F: Send,

impl<F> Sync for StepDecay<F>
where F: Sync,

impl<F> Unpin for StepDecay<F>
where F: Unpin,

impl<F> UnwindSafe for StepDecay<F>
where F: UnwindSafe,

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self). That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true; otherwise converts self into a Right variant.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true; otherwise converts self into a Right variant.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V