/// Main trainer for LSTM networks with configurable loss and optimizer.
///
/// The trainer is generic over the loss function type `L` and the optimizer
/// type `O`, so either can be swapped independently of the other.
pub struct LSTMTrainer<L, O>
where
    L: LossFunction,
    O: Optimizer,
{
    /// The LSTM network driven by this trainer.
    pub network: LSTMNetwork,
    /// The loss function, of the generic type `L`.
    pub loss_function: L,
    /// The optimizer, of the generic type `O`.
    pub optimizer: O,
    /// Training configuration.
    /// NOTE(review): presumably the value installed by `with_config` — confirm.
    pub config: TrainingConfig,
    /// Recorded training metrics.
    /// NOTE(review): presumably what `get_metrics_history` / `get_latest_metrics`
    /// expose, possibly one entry per epoch — confirm.
    pub metrics_history: Vec<TrainingMetrics>,
}
Main trainer for LSTM networks with configurable loss and optimizer.
Fields
network: LSTMNetwork
loss_function: L
optimizer: O
config: TrainingConfig
metrics_history: Vec<TrainingMetrics>
Implementations
impl<L: LossFunction, O: Optimizer> LSTMTrainer<L, O>
pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self
pub fn with_config(self, config: TrainingConfig) -> Self
pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64
Train on a single sequence using backpropagation through time (BPTT).
pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>)
Train for multiple epochs with optional validation.
pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64
Evaluate model performance on validation data.
pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>>
Generate predictions for input sequences.
pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics>
pub fn get_metrics_history(&self) -> &[TrainingMetrics]
pub fn set_training_mode(&mut self, training: bool)
Set network to training mode.
Auto Trait Implementations
impl<L, O> Freeze for LSTMTrainer<L, O>
impl<L, O> RefUnwindSafe for LSTMTrainer<L, O>
where
    L: RefUnwindSafe,
    O: RefUnwindSafe,
impl<L, O> Send for LSTMTrainer<L, O>
impl<L, O> Sync for LSTMTrainer<L, O>
impl<L, O> Unpin for LSTMTrainer<L, O>
impl<L, O> UnwindSafe for LSTMTrainer<L, O>
where
    L: UnwindSafe,
    O: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.