Struct ScheduledLSTMTrainer

Source
/// Specialized trainer for scheduled optimizers that automatically steps the
/// learning-rate scheduler while training an [`LSTMNetwork`].
///
/// Generic over the loss function `L`, the base optimizer `O`, and the
/// learning-rate scheduler `S`; the optimizer and scheduler are held together
/// in a [`ScheduledOptimizer`].
pub struct ScheduledLSTMTrainer<L: LossFunction, O: Optimizer, S: LearningRateScheduler> {
    /// The LSTM network being trained.
    pub network: LSTMNetwork,
    /// Loss function used by the trainer.
    pub loss_function: L,
    /// Optimizer paired with its learning-rate scheduler; the trainer steps
    /// the scheduler automatically during multi-epoch training.
    pub optimizer: ScheduledOptimizer<O, S>,
    /// Training configuration; can be supplied via `with_config`.
    pub config: TrainingConfig,
    /// History of recorded training metrics, exposed through
    /// `get_metrics_history` and `get_latest_metrics`.
    pub metrics_history: Vec<TrainingMetrics>,
}
Specialized trainer for scheduled optimizers that automatically steps the scheduler.

Fields§

- `network: LSTMNetwork`
- `loss_function: L`
- `optimizer: ScheduledOptimizer<O, S>`
- `config: TrainingConfig`
- `metrics_history: Vec<TrainingMetrics>`

Implementations§

Source§

impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S>

Source

pub fn new( network: LSTMNetwork, loss_function: L, optimizer: ScheduledOptimizer<O, S>, ) -> Self

Source

pub fn with_config(self, config: TrainingConfig) -> Self

Source

pub fn train_sequence( &mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>], ) -> f64

Train on a single sequence using backpropagation through time (BPTT).

Source

pub fn train( &mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>, )

Train for multiple epochs with automatic scheduler stepping.

Source

pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64

Evaluate model performance on validation data

Source

pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>>

Generate predictions for an input sequence.

Source

pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics>

Source

pub fn get_metrics_history(&self) -> &[TrainingMetrics]

Source

pub fn set_training_mode(&mut self, training: bool)

Set whether the network is in training mode (`training = true`) or not (`training = false`).

Source

pub fn get_current_lr(&self) -> f64

Get the current learning rate

Source

pub fn get_current_epoch(&self) -> usize

Get the current epoch from the scheduler

Source

pub fn reset_optimizer(&mut self)

Reset the optimizer and scheduler

Auto Trait Implementations§

§

impl<L, O, S> Freeze for ScheduledLSTMTrainer<L, O, S>
where L: Freeze, O: Freeze, S: Freeze,

§

impl<L, O, S> RefUnwindSafe for ScheduledLSTMTrainer<L, O, S>

§

impl<L, O, S> Send for ScheduledLSTMTrainer<L, O, S>
where L: Send, O: Send, S: Send,

§

impl<L, O, S> Sync for ScheduledLSTMTrainer<L, O, S>
where L: Sync, O: Sync, S: Sync,

§

impl<L, O, S> Unpin for ScheduledLSTMTrainer<L, O, S>
where L: Unpin, O: Unpin, S: Unpin,

§

impl<L, O, S> UnwindSafe for ScheduledLSTMTrainer<L, O, S>
where L: UnwindSafe, O: UnwindSafe, S: UnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the `TypeId` of `self`.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V