pub struct ScheduledLSTMTrainer<L: LossFunction, O: Optimizer, S: LearningRateScheduler> {
pub network: LSTMNetwork,
pub loss_function: L,
pub optimizer: ScheduledOptimizer<O, S>,
pub config: TrainingConfig,
pub metrics_history: Vec<TrainingMetrics>,
}
Specialized trainer for scheduled optimizers that automatically steps the learning-rate scheduler.
Fields§
§network: LSTMNetwork
§loss_function: L
§optimizer: ScheduledOptimizer<O, S>
§config: TrainingConfig
§metrics_history: Vec<TrainingMetrics>
Implementations§
impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S>
impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S>
pub fn new( network: LSTMNetwork, loss_function: L, optimizer: ScheduledOptimizer<O, S>, ) -> Self
pub fn with_config(self, config: TrainingConfig) -> Self
pub fn train_sequence(
&mut self,
inputs: &[Array2<f64>],
targets: &[Array2<f64>],
) -> f64
pub fn train_sequence( &mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>], ) -> f64
Train on a single sequence using backpropagation through time (BPTT).
pub fn train(
&mut self,
train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>,
)
pub fn train( &mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>, )
Train for multiple epochs, automatically stepping the learning-rate scheduler.
pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64
pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64
Evaluate model performance on the given data (typically validation data).
pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>>
pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>>
Generate predictions for an input sequence.
pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics>
pub fn get_metrics_history(&self) -> &[TrainingMetrics]
pub fn set_training_mode(&mut self, training: bool)
pub fn set_training_mode(&mut self, training: bool)
Set whether the network is in training mode.
pub fn get_current_lr(&self) -> f64
pub fn get_current_lr(&self) -> f64
Get the current learning rate.
pub fn get_current_epoch(&self) -> usize
pub fn get_current_epoch(&self) -> usize
Get the current epoch from the scheduler.
pub fn reset_optimizer(&mut self)
pub fn reset_optimizer(&mut self)
Reset the optimizer and its scheduler.
Auto Trait Implementations§
impl<L, O, S> Freeze for ScheduledLSTMTrainer<L, O, S>
impl<L, O, S> RefUnwindSafe for ScheduledLSTMTrainer<L, O, S>
impl<L, O, S> Send for ScheduledLSTMTrainer<L, O, S>
impl<L, O, S> Sync for ScheduledLSTMTrainer<L, O, S>
impl<L, O, S> Unpin for ScheduledLSTMTrainer<L, O, S>
impl<L, O, S> UnwindSafe for ScheduledLSTMTrainer<L, O, S>
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.