pub struct StepDecay<F: Float + Debug + ScalarOperand> { /* private fields */ }
Step decay learning rate scheduler
Multiplies the learning rate by a fixed factor every step_size steps (epochs or batches), with min_lr as a lower bound.
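The exact update rule is not spelled out on this page, but the constructor parameters below imply the conventional step-decay schedule. A minimal, self-contained sketch, assuming the usual formula lr = initial_lr * factor^floor(step / step_size) clamped from below at min_lr (the function step_decay_lr is illustrative, not part of this crate):

// Illustrative sketch only: the crate's fields are private, so this models
// the conventional step-decay rule implied by the constructor parameters.
fn step_decay_lr(initial_lr: f64, factor: f64, step_size: usize, min_lr: f64, step: usize) -> f64 {
    // How many full decay intervals have elapsed at this step.
    let reductions = (step / step_size) as i32;
    // Multiply by `factor` once per elapsed interval, never dropping below `min_lr`.
    (initial_lr * factor.powi(reductions)).max(min_lr)
}

fn main() {
    // With initial_lr = 0.01, factor = 0.5, step_size = 30 (the values used in
    // the repository example below), the rate halves every 30 steps:
    // steps 0..=29 -> 0.01, 30..=59 -> 0.005, 60..=89 -> 0.0025, ...
    for step in [0, 29, 30, 59, 60, 90] {
        println!("step {step:3}: lr = {}", step_decay_lr(0.01, 0.5, 30, 1e-4, step));
    }
}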
Implementations

impl<F: Float + Debug + ScalarOperand> StepDecay<F>

pub fn new(
    initial_lr: F,
    factor: F,
    step_size: usize,
    method: ScheduleMethod,
    min_lr: F,
) -> Self
Create a new step decay scheduler
Arguments

initial_lr - Initial learning rate
factor - Factor to multiply the learning rate by (should be less than 1.0)
step_size - Number of steps between reductions
method - Whether to schedule on epochs or batches
min_lr - Minimum learning rate
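A minimal construction sketch with f64, which satisfies the Float + Debug + ScalarOperand bounds, using the same values as the repository example below (imports of StepDecay and ScheduleMethod from this crate are omitted):

// Sketch: halve the learning rate every 30 epochs, floored at 1e-4.
let scheduler: StepDecay<f64> = StepDecay::new(
    0.01,                  // initial_lr
    0.5,                   // factor, should be < 1.0
    30,                    // step_size: decay every 30 scheduling steps
    ScheduleMethod::Epoch, // step on epochs rather than batches
    1e-4,                  // min_lr: lower bound for the decayed rate
);
assert_eq!(scheduler.get_initial_lr(), 0.01);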
Examples found in repository
examples/visualize_training_progress.rs (lines 115-121)
83fn main() -> Result<()> {
84 println!("Training Visualization Example");
85 println!("==============================\n");
86
87 // Initialize RNG with a fixed seed for reproducibility
88 let mut rng = SmallRng::seed_from_u64(42);
89
90 // Generate synthetic data
91 let num_samples = 200;
92 let (x, y) = generate_nonlinear_data(num_samples, &mut rng);
93 println!("Generated synthetic dataset with {} samples", num_samples);
94
95 // Split data into training and validation sets
96 let (x_train, y_train, x_val, y_val) = train_val_split(&x, &y, 0.8);
97 println!(
98 "Split data: {} training samples, {} validation samples",
99 x_train.shape()[0],
100 x_val.shape()[0]
101 );
102
103 // Create model
104 let mut model = create_regression_model(1, &mut rng)?;
105 println!("Created model with {} layers", model.num_layers());
106
107 // Setup loss function and optimizer
108 let loss_fn = MeanSquaredError::new();
109 let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
110
111 // Training parameters
112 let epochs = 100;
113
114 // Configure learning rate scheduler
115 let mut scheduler = StepDecay::new(
116 0.01, // Initial learning rate
117 0.5, // Decay factor
118 30, // Step size
119 ScheduleMethod::Epoch,
120 1e-4, // Min learning rate
121 );
122
123 // Add visualization callback
124 let mut visualization_cb =
125 VisualizationCallback::new(5) // Show every 5 epochs
126 .with_tracked_metrics(vec!["train_loss".to_string(), "val_loss".to_string()])
127 .with_plot_options(PlotOptions {
128 width: 80,
129 height: 15,
130 max_x_ticks: 10,
131 max_y_ticks: 5,
132 line_char: '─',
133 point_char: '●',
134 background_char: ' ',
135 show_grid: true,
136 show_legend: true,
137 })
138 .with_save_path("training_plot.txt");
139
140 // Train the model manually
141 let mut epoch_history = HashMap::new();
142 epoch_history.insert("train_loss".to_string(), Vec::new());
143 epoch_history.insert("val_loss".to_string(), Vec::new());
144
145 // Convert data to dynamic arrays and ensure they are owned (not views)
146 let x_train_dyn = x_train.clone().into_dyn();
147 let y_train_dyn = y_train.clone().into_dyn();
148 let x_val_dyn = x_val.clone().into_dyn();
149 let y_val_dyn = y_val.clone().into_dyn();
150
151 println!("\nStarting training with visualization...");
152
153 // Manual training loop
154 println!("\nStarting training loop...");
155 for epoch in 0..epochs {
156 // Update learning rate with scheduler
157 let current_lr = scheduler.get_lr();
158 optimizer.set_learning_rate(current_lr);
159
160 // Train for one epoch (batch size = full dataset in this example)
161 let train_loss = model.train_batch(&x_train_dyn, &y_train_dyn, &loss_fn, &mut optimizer)?;
162
163 // Compute validation loss
164 let predictions = model.forward(&x_val_dyn)?;
165 let val_loss = loss_fn.forward(&predictions, &y_val_dyn)?;
166
167 // Store metrics
168 epoch_history
169 .get_mut("train_loss")
170 .unwrap()
171 .push(train_loss);
172 epoch_history.get_mut("val_loss").unwrap().push(val_loss);
173
174 // Update the scheduler
175 scheduler.update_lr(epoch);
176
177 // Print progress
178 if epoch % 10 == 0 || epoch == epochs - 1 {
179 println!(
180 "Epoch {}/{}: train_loss = {:.6}, val_loss = {:.6}, lr = {:.6}",
181 epoch + 1,
182 epochs,
183 train_loss,
184 val_loss,
185 current_lr
186 );
187 }
188
189 // Visualize progress (manually calling the visualization callback)
190 if epoch % 5 == 0 || epoch == epochs - 1 {
191 let mut context = CallbackContext {
192 epoch,
193 total_epochs: epochs,
194 batch: 0,
195 total_batches: 1,
196 batch_loss: None,
197 epoch_loss: Some(train_loss),
198 val_loss: Some(val_loss),
199 metrics: vec![],
200 history: &epoch_history,
201 stop_training: false,
202 model: None,
203 };
204
205 visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
206 }
207 }
208
209 // Final visualization
210 let mut context = CallbackContext {
211 epoch: epochs - 1,
212 total_epochs: epochs,
213 batch: 0,
214 total_batches: 1,
215 batch_loss: None,
216 epoch_loss: Some(*epoch_history.get("train_loss").unwrap().last().unwrap()),
217 val_loss: Some(*epoch_history.get("val_loss").unwrap().last().unwrap()),
218 metrics: vec![],
219 history: &epoch_history,
220 stop_training: false,
221 model: None,
222 };
223
224 visualization_cb.on_event(CallbackTiming::AfterTraining, &mut context)?;
225
226 println!("\nTraining complete!");
227
228 // Export history to CSV
229 export_history_to_csv(&epoch_history, "training_history.csv")?;
230 println!("Exported training history to training_history.csv");
231
232 // Analyze training results
233 let analysis = analyze_training_history(&epoch_history);
234 println!("\nTraining Analysis:");
235 println!("------------------");
236 for issue in analysis {
237 println!("{}", issue);
238 }
239
240 // Make predictions on validation data
241 println!("\nMaking predictions on validation data...");
242 let predictions = model.forward(&x_val_dyn)?;
243
244 // Calculate and display final metrics
245 let mse = loss_fn.forward(&predictions, &y_val_dyn)?;
246 println!("Final validation MSE: {:.6}", mse);
247
248 // Display a few sample predictions
249 println!("\nSample predictions:");
250 println!("------------------");
251 println!(" X | True Y | Predicted Y ");
252 println!("---------------------------");
253
254 // Show first 5 predictions
255 let num_samples_to_show = std::cmp::min(5, x_val.shape()[0]);
256 for i in 0..num_samples_to_show {
257 println!(
258 "{:.4} | {:.4} | {:.4}",
259 x_val[[i, 0]],
260 y_val[[i, 0]],
261 predictions[[i, 0]]
262 );
263 }
264
265 println!("\nVisualization demonstration complete!");
266 Ok(())
267}
pub fn get_initial_lr(&self) -> F
Get the initial learning rate
pub fn get_lr(&self) -> F
Get the current learning rate
Examples found in repository

examples/visualize_training_progress.rs (line 157); see the full listing under StepDecay::new above.
pub fn update_lr(&mut self, step: usize)
Update the learning rate based on the current step
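Condensed from the repository example referenced below: each epoch, read the current rate, hand it to the optimizer, then advance the schedule with the epoch index. Here scheduler, optimizer, and epochs stand for the values set up earlier in that example.

for epoch in 0..epochs {
    // Apply the scheduled rate before this epoch's parameter updates.
    let current_lr = scheduler.get_lr();
    optimizer.set_learning_rate(current_lr);

    // ... train for one epoch at `current_lr` ...

    // Advance the schedule; the rate drops once `epoch` crosses a
    // multiple of `step_size`.
    scheduler.update_lr(epoch);
}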
Examples found in repository

examples/visualize_training_progress.rs (line 175); see the full listing under StepDecay::new above.
pub fn reset_to_initial(&mut self)
Reset the learning rate to the initial value
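A short sketch of restarting the schedule, for example between training restarts or cross-validation folds (assumes a scheduler built as above):

// After any number of update_lr calls, return to the starting rate.
scheduler.reset_to_initial();
assert_eq!(scheduler.get_lr(), scheduler.get_initial_lr());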
Trait Implementations

impl<F: Float + Debug + ScalarOperand> Callback<F> for StepDecay<F>

fn on_event(
    &mut self,
    timing: CallbackTiming,
    context: &mut CallbackContext<'_, F>,
) -> Result<()>
Called during training at specific points
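Because StepDecay implements Callback, it can be driven by the same event mechanism the repository example uses for its visualization callback, instead of calling update_lr manually. A sketch, assuming a CallbackContext built as in that example and that the scheduler advances on epoch-end events:

// Let the scheduler advance itself when an epoch ends.
scheduler.on_event(CallbackTiming::AfterEpoch, &mut context)?;
let lr = scheduler.get_lr(); // then apply to the optimizer as usual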
impl<F: Float + Debug + ScalarOperand> LearningRateScheduler<F> for StepDecay<F>
Auto Trait Implementations

impl<F> Freeze for StepDecay<F> where F: Freeze
impl<F> RefUnwindSafe for StepDecay<F> where F: RefUnwindSafe
impl<F> Send for StepDecay<F> where F: Send
impl<F> Sync for StepDecay<F> where F: Sync
impl<F> Unpin for StepDecay<F> where F: Unpin
impl<F> UnwindSafe for StepDecay<F> where F: UnwindSafe
Blanket Implementations

impl<T> BorrowMut<T> for T where T: ?Sized

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.