Enum LearningRateSchedule

pub enum LearningRateSchedule<F: Float> {
    Constant(F),
    StepDecay {
        initial_lr: F,
        decay_factor: F,
        step_size: usize,
    },
    ExponentialDecay {
        initial_lr: F,
        decay_factor: F,
    },
    Custom(Box<dyn Fn(usize) -> F>),
}

A simple utility for generating learning rate schedules.

Variants

Constant(F)

Constant learning rate

StepDecay

Step decay learning rate

Fields

initial_lr: F

Initial learning rate

decay_factor: F

Decay factor

step_size: usize

Epochs per step

ExponentialDecay

Exponential decay learning rate

Fields

initial_lr: F

Initial learning rate

decay_factor: F

Decay factor

Custom(Box<dyn Fn(usize) -> F>)

Custom learning rate schedule function
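
As a quick orientation before the method docs, here is a hedged sketch of constructing each variant. The import path and the decay semantics noted in the comments are assumptions, not guarantees about this crate; adjust the path to wherever LearningRateSchedule actually lives.

use scirs2_neural::utils::LearningRateSchedule; // path assumed for illustration

// Fixed rate for every epoch
let constant = LearningRateSchedule::Constant(0.01f32);

// Conventionally: multiply the rate by decay_factor once every step_size epochs
let step = LearningRateSchedule::StepDecay {
    initial_lr: 0.1f32,
    decay_factor: 0.5,
    step_size: 10,
};

// Conventionally: multiply the rate by decay_factor every epoch
let exponential = LearningRateSchedule::ExponentialDecay {
    initial_lr: 0.1f32,
    decay_factor: 0.95,
};

// Any boxed Fn(usize) -> F works, e.g. a 1/t decay
let custom = LearningRateSchedule::Custom(Box::new(|epoch| 0.1f32 / (1.0 + epoch as f32)));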

Implementations

impl<F: Float> LearningRateSchedule<F>

pub fn get_learning_rate(&self, epoch: usize) -> F

Get the learning rate for a given epoch
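
The decay formulas are not spelled out here, so the sketch below assumes the usual conventions rather than asserting this crate's exact behavior: Constant returns the same value every epoch, StepDecay returns initial_lr * decay_factor^(epoch / step_size) with integer division, ExponentialDecay returns initial_lr * decay_factor^epoch, and Custom returns f(epoch).

let step = LearningRateSchedule::StepDecay {
    initial_lr: 0.1f32,
    decay_factor: 0.5,
    step_size: 10,
};
// Under the assumed formula, epochs 0..=9 yield 0.1 and epochs 10..=19 yield 0.05.
let lr_start = step.get_learning_rate(0);  // expected: 0.1
let lr_later = step.get_learning_rate(10); // expected: 0.05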

pub fn generate_schedule(&self, num_epochs: usize) -> Array1<F>

Generate the learning rate schedule for all epochs

Arguments
  • num_epochs - Number of epochs
Returns
  • Array1<F> - Learning rate for each epoch
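
A minimal sketch of calling generate_schedule (the resulting values depend on the decay formulas discussed under get_learning_rate above):

let schedule = LearningRateSchedule::ExponentialDecay {
    initial_lr: 0.1f32,
    decay_factor: 0.9,
};
let lrs = schedule.generate_schedule(5); // Array1<f32>, one entry per epoch
for (epoch, lr) in lrs.iter().enumerate() {
    println!("epoch {}: lr = {:.6}", epoch, lr);
}
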
Examples found in repository
examples/training_loop_example.rs (line 92)
19    fn main() -> Result<()> {
20    println!("Training loop example with visualization");
21
22    // Create dummy data
23    let n_samples = 1000;
24    let n_features = 10;
25    let n_classes = 3;
26
27    println!(
28        "Generating dummy data with {} samples, {} features, {} classes",
29        n_samples, n_features, n_classes
30    );
31
32    // Generate random features
33    let features = Array::from_shape_fn(IxDyn(&[n_samples, n_features]), |_| {
34        rand::random::<f32>() * 2.0 - 1.0
35    });
36
37    // Generate random labels (integers 0 to n_classes-1)
38    let labels = Array::from_shape_fn(IxDyn(&[n_samples, 1]), |_| {
39        (rand::random::<f32>() * n_classes as f32).floor()
40    });
41
42    // Create dataset
43    let dataset = InMemoryDataset::new(features, labels)?;
44
45    // Split into training and validation sets
46    let (train_dataset, val_dataset) = dataset.train_test_split(0.2)?;
47
48    println!(
49        "Split data into {} training samples and {} validation samples",
50        train_dataset.features.shape()[0],
51        val_dataset.features.shape()[0]
52    );
53
54    // Create transformations
55    let feature_scaler = StandardScaler::new(false);
56    let label_encoder = OneHotEncoder::new(n_classes);
57
58    // Apply transformations
59    let train_dataset = TransformedDataset::new(train_dataset)
60        .with_feature_transform(feature_scaler)
61        .with_label_transform(label_encoder);
62
63    let val_dataset = TransformedDataset::new(val_dataset)
64        .with_feature_transform(StandardScaler::new(false))
65        .with_label_transform(OneHotEncoder::new(n_classes));
66
67    // Create data loaders
68    let batch_size = 32;
69    let train_loader = DataLoader::new(train_dataset.clone(), batch_size, true, false);
70    let val_loader = DataLoader::new(val_dataset.clone(), batch_size, false, false);
71
72    println!(
73        "Created data loaders with batch size {}. Training: {} batches, Validation: {} batches",
74        batch_size,
75        train_loader.num_batches(),
76        val_loader.num_batches()
77    );
78
79    // Create model, loss, and optimizer
80    let _model = create_model(n_features, n_classes)?;
81    let _loss_fn = CrossEntropyLoss::new(1e-10);
82
83    // Create learning rate schedule
84    let lr_schedule = LearningRateSchedule::StepDecay {
85        initial_lr: 0.001,
86        decay_factor: 0.5,
87        step_size: 3,
88    };
89
90    // Generate learning rates for all epochs
91    let num_epochs = 10;
92    let learning_rates = lr_schedule.generate_schedule(num_epochs);
93    println!("Learning rate schedule:");
94    for (i, &lr) in learning_rates.iter().enumerate() {
95        println!("  Epoch {}: {:.6}", i + 1, lr);
96    }
97
98    let _optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
99
100    // Create callbacks
101    let _checkpoint_dir = PathBuf::from("./checkpoints");
102    let tensorboard_dir = PathBuf::from("./logs");
103
104    // Create output directories if they don't exist
105    create_dir_if_not_exists("./checkpoints")?;
106    create_dir_if_not_exists("./logs")?;
107    create_dir_if_not_exists("./outputs")?;
108
109    // For this example, we'll just remove the ModelCheckpoint
110    let mut callbacks: Vec<Box<dyn scirs2_neural::callbacks::Callback<f32>>> = vec![
111        Box::new(EarlyStopping::new(5, 0.001, true)),
112        // ModelCheckpoint removed for simplicity as it requires special handling
113        Box::new(ReduceOnPlateau::new(0.001, 0.5, 3, 0.001, 0.0001)),
114        Box::new(TensorBoardLogger::new(tensorboard_dir, true, 10)),
115        // Add our visualization callback
116        Box::new(
117            VisualizationCallback::new(1)
118                .with_save_path("./outputs/training_plot.txt")
119                .with_tracked_metrics(vec![
120                    "train_loss".to_string(),
121                    "val_loss".to_string(),
122                    "accuracy".to_string(),
123                    "learning_rate".to_string(),
124                ]),
125        ),
126    ];
127
128    // Training loop
129    let mut history = HashMap::<String, Vec<f32>>::new();
130    history.insert("train_loss".to_string(), Vec::new());
131    history.insert("val_loss".to_string(), Vec::new());
132    history.insert("learning_rate".to_string(), Vec::new());
133    history.insert("accuracy".to_string(), Vec::new());
134
135    println!("Starting training for {} epochs", num_epochs);
136
137    // Run callbacks before training
138    // Create a copy of history for the context
139    let mut context_history = HashMap::<String, Vec<f32>>::new();
140    context_history.insert("train_loss".to_string(), Vec::new());
141    context_history.insert("val_loss".to_string(), Vec::new());
142    context_history.insert("learning_rate".to_string(), Vec::new());
143    context_history.insert("accuracy".to_string(), Vec::new());
144
145    // For this example, we adapt to use Vec<F> for metrics
146    // which is simpler than using Vec<(String, Option<F>)>
147    // In a real implementation, use the proper context format
148    let mut context = CallbackContext {
149        epoch: 0,
150        total_epochs: num_epochs,
151        batch: 0,
152        total_batches: train_loader.num_batches(),
153        batch_loss: None,
154        epoch_loss: None,
155        val_loss: None,
156        metrics: Vec::new(),
157        history: &context_history,
158        stop_training: false,
159        model: None,
160    };
161
162    for callback in &mut callbacks {
163        callback.on_event(CallbackTiming::BeforeTraining, &mut context)?;
164    }
165
166    // Training loop
167    for epoch in 0..num_epochs {
168        println!("Epoch {}/{}", epoch + 1, num_epochs);
169
170        // Get learning rate for this epoch
171        let learning_rate = learning_rates[epoch];
172        history
173            .get_mut("learning_rate")
174            .unwrap()
175            .push(learning_rate);
176
177        // Reset data loader
178        let mut train_loader = DataLoader::new(train_dataset.clone(), batch_size, true, false);
179        train_loader.reset();
180
181        // Update context
182        context.epoch = epoch;
183        context.epoch_loss = None;
184        context.val_loss = None;
185
186        // Run callbacks before epoch
187        for callback in &mut callbacks {
188            callback.on_event(CallbackTiming::BeforeEpoch, &mut context)?;
189        }
190
191        // Train on batches
192        let mut epoch_loss = 0.0;
193        let mut batch_count = 0;
194
195        for (batch, batch_result) in train_loader.enumerate() {
196            let (_batch_x, _batch_y) = batch_result?;
197
198            // Update context
199            context.batch = batch;
200            context.batch_loss = None;
201
202            // Run callbacks before batch
203            for callback in &mut callbacks {
204                callback.on_event(CallbackTiming::BeforeBatch, &mut context)?;
205            }
206
207            // In a real implementation, we'd train the model here
208            // For now, just compute a random loss
209            let batch_loss = rand::random::<f32>() * (1.0 / (epoch as f32 + 1.0));
210
211            // Update batch loss
212            context.batch_loss = Some(batch_loss);
213
214            // Run callbacks after batch
215            for callback in &mut callbacks {
216                callback.on_event(CallbackTiming::AfterBatch, &mut context)?;
217            }
218
219            epoch_loss += batch_loss;
220            batch_count += 1;
221        }
222
223        // Compute epoch loss
224        epoch_loss /= batch_count as f32;
225        history.get_mut("train_loss").unwrap().push(epoch_loss);
226        context.epoch_loss = Some(epoch_loss);
227
228        println!("Train loss: {:.6}", epoch_loss);
229
230        // Evaluate on validation set
231        let mut val_loss = 0.0;
232        let mut val_batch_count = 0;
233
234        let mut val_loader = DataLoader::new(val_dataset.clone(), batch_size, false, false);
235        val_loader.reset();
236
237        for batch_result in val_loader {
238            let (_batch_x, _batch_y) = batch_result?;
239
240            // In a real implementation, we'd evaluate the model here
241            // For now, just compute a random loss
242            let batch_loss = rand::random::<f32>() * (1.0 / (epoch as f32 + 1.0)) * 1.1;
243
244            val_loss += batch_loss;
245            val_batch_count += 1;
246        }
247
248        // Compute validation loss
249        val_loss /= val_batch_count as f32;
250        history.get_mut("val_loss").unwrap().push(val_loss);
251        context.val_loss = Some(val_loss);
252
253        // Simulate accuracy metric
254        let accuracy =
255            0.5 + 0.4 * (epoch as f32 / num_epochs as f32) + rand::random::<f32>() * 0.05;
256        history.get_mut("accuracy").unwrap().push(accuracy);
257
258        // Add accuracy to metrics
259        context.metrics = vec![accuracy];
260
261        println!("Validation loss: {:.6}", val_loss);
262        println!("Accuracy: {:.2}%", accuracy * 100.0);
263
264        // Run callbacks after epoch
265        for callback in &mut callbacks {
266            callback.on_event(CallbackTiming::AfterEpoch, &mut context)?;
267        }
268
269        // Check if training should be stopped
270        if context.stop_training {
271            println!("Early stopping triggered, terminating training");
272            break;
273        }
274
275        // Visualize after each epoch
276        if epoch > 0 {
277            // Plot training and validation loss
278            let loss_plot = ascii_plot(
279                &history,
280                Some("Training and Validation Loss"),
281                Some(PlotOptions {
282                    width: 80,
283                    height: 20,
284                    max_x_ticks: 10,
285                    max_y_ticks: 5,
286                    line_char: '─',
287                    point_char: '●',
288                    background_char: ' ',
289                    show_grid: true,
290                    show_legend: true,
291                }),
292            )?;
293            println!("\n{}", loss_plot);
294        }
295    }
296
297    // Run callbacks after training
298    for callback in &mut callbacks {
299        callback.on_event(CallbackTiming::AfterTraining, &mut context)?;
300    }
301
302    println!("Training complete!");
303
304    // Export metrics to CSV
305    let csv_path = "./outputs/training_history.csv";
306    export_history_to_csv(&history, csv_path)?;
307    println!("Training history exported to {}", csv_path);
308
309    // Analyze training history
310    let analysis = analyze_training_history(&history);
311    println!("\nTraining Analysis:");
312    for issue in analysis {
313        println!("  {}", issue);
314    }
315
316    // Final visualization of metrics
317    println!("\nFinal Training Metrics:\n");
318
319    // Prepare subset of metrics for separate accuracy plot
320    let mut accuracy_data = HashMap::new();
321    accuracy_data.insert(
322        "accuracy".to_string(),
323        history.get("accuracy").unwrap().clone(),
324    );
325
326    // Plot accuracy
327    let accuracy_plot = ascii_plot(
328        &accuracy_data,
329        Some("Model Accuracy"),
330        Some(PlotOptions {
331            width: 80,
332            height: 15,
333            max_x_ticks: 10,
334            max_y_ticks: 5,
335            line_char: '─',
336            point_char: '●',
337            background_char: ' ',
338            show_grid: true,
339            show_legend: true,
340        }),
341    )?;
342    println!("{}", accuracy_plot);
343
344    // Prepare subset of metrics for learning rate plot
345    let mut lr_data = HashMap::new();
346    lr_data.insert(
347        "learning_rate".to_string(),
348        history.get("learning_rate").unwrap().clone(),
349    );
350
351    // Plot learning rate
352    let lr_plot = ascii_plot(
353        &lr_data,
354        Some("Learning Rate Schedule"),
355        Some(PlotOptions {
356            width: 80,
357            height: 15,
358            max_x_ticks: 10,
359            max_y_ticks: 5,
360            line_char: '─',
361            point_char: '■',
362            background_char: ' ',
363            show_grid: true,
364            show_legend: true,
365        }),
366    )?;
367    println!("{}", lr_plot);
368
369    // Visualize both train and validation losses in a single plot
370    let mut loss_data = HashMap::new();
371    loss_data.insert(
372        "train_loss".to_string(),
373        history.get("train_loss").unwrap().clone(),
374    );
375    loss_data.insert(
376        "val_loss".to_string(),
377        history.get("val_loss").unwrap().clone(),
378    );
379
380    let loss_plot = ascii_plot(
381        &loss_data,
382        Some("Training and Validation Loss"),
383        Some(PlotOptions {
384            width: 80,
385            height: 20,
386            max_x_ticks: 10,
387            max_y_ticks: 5,
388            line_char: '─',
389            point_char: '●',
390            background_char: ' ',
391            show_grid: true,
392            show_legend: true,
393        }),
394    )?;
395    println!("{}", loss_plot);
396
397    Ok(())
398}

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V