pub enum LearningRateSchedule<F: Float> {
    Constant(F),
    StepDecay {
        initial_lr: F,
        decay_factor: F,
        step_size: usize,
    },
    ExponentialDecay {
        initial_lr: F,
        decay_factor: F,
    },
    Custom(Box<dyn Fn(usize) -> F>),
}
A simple utility for generating a learning rate schedule.
Variants

Constant(F)
    Constant learning rate
StepDecay { initial_lr: F, decay_factor: F, step_size: usize }
    Step decay learning rate
ExponentialDecay { initial_lr: F, decay_factor: F }
    Exponential decay learning rate
Custom(Box<dyn Fn(usize) -> F>)
    Custom learning rate schedule function
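A minimal sketch of constructing each variant, assuming f32 as the float type. The parameter values are arbitrary, the use path is a guess rather than the crate's actual module layout, and the decay semantics in the comments follow the conventional definitions rather than quoting the crate docs.

// Hypothetical import path; adjust to wherever LearningRateSchedule lives in scirs2-neural.
use scirs2_neural::utils::LearningRateSchedule;

// Fixed learning rate for every epoch
let constant = LearningRateSchedule::Constant(0.01f32);

// Step decay: start at initial_lr, scale by decay_factor every step_size epochs
let step = LearningRateSchedule::StepDecay {
    initial_lr: 0.01f32,
    decay_factor: 0.5,
    step_size: 3,
};

// Exponential decay: start at initial_lr, scale by decay_factor as epochs advance
let exponential = LearningRateSchedule::ExponentialDecay {
    initial_lr: 0.01f32,
    decay_factor: 0.9,
};

// Custom schedule: any closure from the epoch index to a learning rate
let custom = LearningRateSchedule::Custom(Box::new(|epoch: usize| 0.01f32 / (1.0 + epoch as f32)));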
Implementations

impl<F: Float> LearningRateSchedule<F>

pub fn get_learning_rate(&self, epoch: usize) -> F
Get the learning rate for a given epoch
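A minimal sketch of querying the rate per epoch. The comment assumes the conventional step-decay rule (initial_lr scaled by decay_factor once per completed step_size epochs); the crate's exact formula is not reproduced here.

let schedule = LearningRateSchedule::StepDecay {
    initial_lr: 0.1f32,
    decay_factor: 0.5,
    step_size: 3,
};

// Under the conventional step-decay rule this prints roughly:
// epochs 0-2 -> 0.1, epochs 3-5 -> 0.05, epochs 6-8 -> 0.025
for epoch in 0..9 {
    println!("epoch {}: lr = {:.4}", epoch, schedule.get_learning_rate(epoch));
}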
pub fn generate_schedule(&self, num_epochs: usize) -> Array1<F>
Generate the learning rate schedule for all epochs
Arguments

num_epochs - Number of epochs

Returns

Array1<F> - Learning rate for each epoch
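A minimal standalone sketch of generate_schedule; a complete training loop that uses the returned array appears in the repository example below.

let schedule = LearningRateSchedule::ExponentialDecay {
    initial_lr: 0.01f32,
    decay_factor: 0.9,
};

// One learning rate per epoch, returned as an ndarray Array1<f32>
let learning_rates = schedule.generate_schedule(10);
assert_eq!(learning_rates.len(), 10);

for (epoch, lr) in learning_rates.iter().enumerate() {
    println!("epoch {}: {:.6}", epoch + 1, lr);
}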
Examples found in repository
examples/training_loop_example.rs (line 92)
19 fn main() -> Result<()> {
20 println!("Training loop example with visualization");
21
22 // Create dummy data
23 let n_samples = 1000;
24 let n_features = 10;
25 let n_classes = 3;
26
27 println!(
28 "Generating dummy data with {} samples, {} features, {} classes",
29 n_samples, n_features, n_classes
30 );
31
32 // Generate random features
33 let features = Array::from_shape_fn(IxDyn(&[n_samples, n_features]), |_| {
34 rand::random::<f32>() * 2.0 - 1.0
35 });
36
37 // Generate random labels (integers 0 to n_classes-1)
38 let labels = Array::from_shape_fn(IxDyn(&[n_samples, 1]), |_| {
39 (rand::random::<f32>() * n_classes as f32).floor()
40 });
41
42 // Create dataset
43 let dataset = InMemoryDataset::new(features, labels)?;
44
45 // Split into training and validation sets
46 let (train_dataset, val_dataset) = dataset.train_test_split(0.2)?;
47
48 println!(
49 "Split data into {} training samples and {} validation samples",
50 train_dataset.features.shape()[0],
51 val_dataset.features.shape()[0]
52 );
53
54 // Create transformations
55 let feature_scaler = StandardScaler::new(false);
56 let label_encoder = OneHotEncoder::new(n_classes);
57
58 // Apply transformations
59 let train_dataset = TransformedDataset::new(train_dataset)
60 .with_feature_transform(feature_scaler)
61 .with_label_transform(label_encoder);
62
63 let val_dataset = TransformedDataset::new(val_dataset)
64 .with_feature_transform(StandardScaler::new(false))
65 .with_label_transform(OneHotEncoder::new(n_classes));
66
67 // Create data loaders
68 let batch_size = 32;
69 let train_loader = DataLoader::new(train_dataset.clone(), batch_size, true, false);
70 let val_loader = DataLoader::new(val_dataset.clone(), batch_size, false, false);
71
72 println!(
73 "Created data loaders with batch size {}. Training: {} batches, Validation: {} batches",
74 batch_size,
75 train_loader.num_batches(),
76 val_loader.num_batches()
77 );
78
79 // Create model, loss, and optimizer
80 let _model = create_model(n_features, n_classes)?;
81 let _loss_fn = CrossEntropyLoss::new(1e-10);
82
83 // Create learning rate schedule
84 let lr_schedule = LearningRateSchedule::StepDecay {
85 initial_lr: 0.001,
86 decay_factor: 0.5,
87 step_size: 3,
88 };
89
90 // Generate learning rates for all epochs
91 let num_epochs = 10;
92 let learning_rates = lr_schedule.generate_schedule(num_epochs);
93 println!("Learning rate schedule:");
94 for (i, &lr) in learning_rates.iter().enumerate() {
95 println!(" Epoch {}: {:.6}", i + 1, lr);
96 }
97
98 let _optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
99
100 // Create callbacks
101 let _checkpoint_dir = PathBuf::from("./checkpoints");
102 let tensorboard_dir = PathBuf::from("./logs");
103
104 // Create output directories if they don't exist
105 create_dir_if_not_exists("./checkpoints")?;
106 create_dir_if_not_exists("./logs")?;
107 create_dir_if_not_exists("./outputs")?;
108
109 // For this example, we'll just remove the ModelCheckpoint
110 let mut callbacks: Vec<Box<dyn scirs2_neural::callbacks::Callback<f32>>> = vec![
111 Box::new(EarlyStopping::new(5, 0.001, true)),
112 // ModelCheckpoint removed for simplicity as it requires special handling
113 Box::new(ReduceOnPlateau::new(0.001, 0.5, 3, 0.001, 0.0001)),
114 Box::new(TensorBoardLogger::new(tensorboard_dir, true, 10)),
115 // Add our visualization callback
116 Box::new(
117 VisualizationCallback::new(1)
118 .with_save_path("./outputs/training_plot.txt")
119 .with_tracked_metrics(vec![
120 "train_loss".to_string(),
121 "val_loss".to_string(),
122 "accuracy".to_string(),
123 "learning_rate".to_string(),
124 ]),
125 ),
126 ];
127
128 // Training loop
129 let mut history = HashMap::<String, Vec<f32>>::new();
130 history.insert("train_loss".to_string(), Vec::new());
131 history.insert("val_loss".to_string(), Vec::new());
132 history.insert("learning_rate".to_string(), Vec::new());
133 history.insert("accuracy".to_string(), Vec::new());
134
135 println!("Starting training for {} epochs", num_epochs);
136
137 // Run callbacks before training
138 // Create a copy of history for the context
139 let mut context_history = HashMap::<String, Vec<f32>>::new();
140 context_history.insert("train_loss".to_string(), Vec::new());
141 context_history.insert("val_loss".to_string(), Vec::new());
142 context_history.insert("learning_rate".to_string(), Vec::new());
143 context_history.insert("accuracy".to_string(), Vec::new());
144
145 // For this example, we adapt to use Vec<F> for metrics
146 // which is simpler than using Vec<(String, Option<F>)>
147 // In a real implementation, use the proper context format
148 let mut context = CallbackContext {
149 epoch: 0,
150 total_epochs: num_epochs,
151 batch: 0,
152 total_batches: train_loader.num_batches(),
153 batch_loss: None,
154 epoch_loss: None,
155 val_loss: None,
156 metrics: Vec::new(),
157 history: &context_history,
158 stop_training: false,
159 model: None,
160 };
161
162 for callback in &mut callbacks {
163 callback.on_event(CallbackTiming::BeforeTraining, &mut context)?;
164 }
165
166 // Training loop
167 for epoch in 0..num_epochs {
168 println!("Epoch {}/{}", epoch + 1, num_epochs);
169
170 // Get learning rate for this epoch
171 let learning_rate = learning_rates[epoch];
172 history
173 .get_mut("learning_rate")
174 .unwrap()
175 .push(learning_rate);
176
177 // Reset data loader
178 let mut train_loader = DataLoader::new(train_dataset.clone(), batch_size, true, false);
179 train_loader.reset();
180
181 // Update context
182 context.epoch = epoch;
183 context.epoch_loss = None;
184 context.val_loss = None;
185
186 // Run callbacks before epoch
187 for callback in &mut callbacks {
188 callback.on_event(CallbackTiming::BeforeEpoch, &mut context)?;
189 }
190
191 // Train on batches
192 let mut epoch_loss = 0.0;
193 let mut batch_count = 0;
194
195 for (batch, batch_result) in train_loader.enumerate() {
196 let (_batch_x, _batch_y) = batch_result?;
197
198 // Update context
199 context.batch = batch;
200 context.batch_loss = None;
201
202 // Run callbacks before batch
203 for callback in &mut callbacks {
204 callback.on_event(CallbackTiming::BeforeBatch, &mut context)?;
205 }
206
207 // In a real implementation, we'd train the model here
208 // For now, just compute a random loss
209 let batch_loss = rand::random::<f32>() * (1.0 / (epoch as f32 + 1.0));
210
211 // Update batch loss
212 context.batch_loss = Some(batch_loss);
213
214 // Run callbacks after batch
215 for callback in &mut callbacks {
216 callback.on_event(CallbackTiming::AfterBatch, &mut context)?;
217 }
218
219 epoch_loss += batch_loss;
220 batch_count += 1;
221 }
222
223 // Compute epoch loss
224 epoch_loss /= batch_count as f32;
225 history.get_mut("train_loss").unwrap().push(epoch_loss);
226 context.epoch_loss = Some(epoch_loss);
227
228 println!("Train loss: {:.6}", epoch_loss);
229
230 // Evaluate on validation set
231 let mut val_loss = 0.0;
232 let mut val_batch_count = 0;
233
234 let mut val_loader = DataLoader::new(val_dataset.clone(), batch_size, false, false);
235 val_loader.reset();
236
237 for batch_result in val_loader {
238 let (_batch_x, _batch_y) = batch_result?;
239
240 // In a real implementation, we'd evaluate the model here
241 // For now, just compute a random loss
242 let batch_loss = rand::random::<f32>() * (1.0 / (epoch as f32 + 1.0)) * 1.1;
243
244 val_loss += batch_loss;
245 val_batch_count += 1;
246 }
247
248 // Compute validation loss
249 val_loss /= val_batch_count as f32;
250 history.get_mut("val_loss").unwrap().push(val_loss);
251 context.val_loss = Some(val_loss);
252
253 // Simulate accuracy metric
254 let accuracy =
255 0.5 + 0.4 * (epoch as f32 / num_epochs as f32) + rand::random::<f32>() * 0.05;
256 history.get_mut("accuracy").unwrap().push(accuracy);
257
258 // Add accuracy to metrics
259 context.metrics = vec![accuracy];
260
261 println!("Validation loss: {:.6}", val_loss);
262 println!("Accuracy: {:.2}%", accuracy * 100.0);
263
264 // Run callbacks after epoch
265 for callback in &mut callbacks {
266 callback.on_event(CallbackTiming::AfterEpoch, &mut context)?;
267 }
268
269 // Check if training should be stopped
270 if context.stop_training {
271 println!("Early stopping triggered, terminating training");
272 break;
273 }
274
275 // Visualize after each epoch
276 if epoch > 0 {
277 // Plot training and validation loss
278 let loss_plot = ascii_plot(
279 &history,
280 Some("Training and Validation Loss"),
281 Some(PlotOptions {
282 width: 80,
283 height: 20,
284 max_x_ticks: 10,
285 max_y_ticks: 5,
286 line_char: '─',
287 point_char: '●',
288 background_char: ' ',
289 show_grid: true,
290 show_legend: true,
291 }),
292 )?;
293 println!("\n{}", loss_plot);
294 }
295 }
296
297 // Run callbacks after training
298 for callback in &mut callbacks {
299 callback.on_event(CallbackTiming::AfterTraining, &mut context)?;
300 }
301
302 println!("Training complete!");
303
304 // Export metrics to CSV
305 let csv_path = "./outputs/training_history.csv";
306 export_history_to_csv(&history, csv_path)?;
307 println!("Training history exported to {}", csv_path);
308
309 // Analyze training history
310 let analysis = analyze_training_history(&history);
311 println!("\nTraining Analysis:");
312 for issue in analysis {
313 println!(" {}", issue);
314 }
315
316 // Final visualization of metrics
317 println!("\nFinal Training Metrics:\n");
318
319 // Prepare subset of metrics for separate accuracy plot
320 let mut accuracy_data = HashMap::new();
321 accuracy_data.insert(
322 "accuracy".to_string(),
323 history.get("accuracy").unwrap().clone(),
324 );
325
326 // Plot accuracy
327 let accuracy_plot = ascii_plot(
328 &accuracy_data,
329 Some("Model Accuracy"),
330 Some(PlotOptions {
331 width: 80,
332 height: 15,
333 max_x_ticks: 10,
334 max_y_ticks: 5,
335 line_char: '─',
336 point_char: '●',
337 background_char: ' ',
338 show_grid: true,
339 show_legend: true,
340 }),
341 )?;
342 println!("{}", accuracy_plot);
343
344 // Prepare subset of metrics for learning rate plot
345 let mut lr_data = HashMap::new();
346 lr_data.insert(
347 "learning_rate".to_string(),
348 history.get("learning_rate").unwrap().clone(),
349 );
350
351 // Plot learning rate
352 let lr_plot = ascii_plot(
353 &lr_data,
354 Some("Learning Rate Schedule"),
355 Some(PlotOptions {
356 width: 80,
357 height: 15,
358 max_x_ticks: 10,
359 max_y_ticks: 5,
360 line_char: '─',
361 point_char: '■',
362 background_char: ' ',
363 show_grid: true,
364 show_legend: true,
365 }),
366 )?;
367 println!("{}", lr_plot);
368
369 // Visualize both train and validation losses in a single plot
370 let mut loss_data = HashMap::new();
371 loss_data.insert(
372 "train_loss".to_string(),
373 history.get("train_loss").unwrap().clone(),
374 );
375 loss_data.insert(
376 "val_loss".to_string(),
377 history.get("val_loss").unwrap().clone(),
378 );
379
380 let loss_plot = ascii_plot(
381 &loss_data,
382 Some("Training and Validation Loss"),
383 Some(PlotOptions {
384 width: 80,
385 height: 20,
386 max_x_ticks: 10,
387 max_y_ticks: 5,
388 line_char: '─',
389 point_char: '●',
390 background_char: ' ',
391 show_grid: true,
392 show_legend: true,
393 }),
394 )?;
395 println!("{}", loss_plot);
396
397 Ok(())
398}
Auto Trait Implementations

impl<F> Freeze for LearningRateSchedule<F> where F: Freeze
impl<F> !RefUnwindSafe for LearningRateSchedule<F>
impl<F> !Send for LearningRateSchedule<F>
impl<F> !Sync for LearningRateSchedule<F>
impl<F> Unpin for LearningRateSchedule<F> where F: Unpin
impl<F> !UnwindSafe for LearningRateSchedule<F>
Blanket Implementations

impl<T> BorrowMut<T> for T where T: ?Sized

    fn borrow_mut(&mut self) -> &mut T
    Mutably borrows from an owned value.

impl<T> IntoEither for T

    fn into_either(self, into_left: bool) -> Either<Self, Self>
    Converts self into a Left variant of Either<Self, Self> if into_left is true.
    Converts self into a Right variant of Either<Self, Self> otherwise.

    fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
    Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
    Converts self into a Right variant of Either<Self, Self> otherwise.