Struct LSTMBatchTrainer

Source
pub struct LSTMBatchTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
}
Expand description

Batch trainer for LSTM networks with a configurable loss function and optimizer. Processes multiple sequences simultaneously for improved performance.

Fields§

network: LSTMNetwork
loss_function: L
optimizer: O
config: TrainingConfig
metrics_history: Vec<TrainingMetrics>

Implementations§

Source§

impl<L: LossFunction, O: Optimizer> LSTMBatchTrainer<L, O>

Source

pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self

Source

pub fn with_config(self, config: TrainingConfig) -> Self

Source

pub fn train_batch( &mut self, batch_inputs: &[Vec<Array2<f64>>], batch_targets: &[Vec<Array2<f64>>], ) -> f64

Train on a batch of sequences using batch processing

§Arguments
  • batch_inputs - Slice of input sequences; each sequence is a Vec<Array2<f64>>
  • batch_targets - Slice of target sequences; each sequence is a Vec<Array2<f64>>
§Returns
  • Average loss across the batch
Source

pub fn train( &mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>, batch_size: usize, )

Train for multiple epochs with batch processing

§Arguments
  • train_data - Slice of (input_sequence, target_sequence) pairs used for training
  • validation_data - Optional validation pairs, in the same format as train_data
  • batch_size - Number of sequences to process in each batch
Source

pub fn evaluate_batch( &mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], batch_size: usize, ) -> f64

Evaluate model performance using batch processing

Source

pub fn predict_batch( &mut self, inputs: &[Vec<Array2<f64>>], ) -> Vec<Vec<Array2<f64>>>

Generate predictions using batch processing

Source

pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics>

Source

pub fn get_metrics_history(&self) -> &[TrainingMetrics]

Source

pub fn set_training_mode(&mut self, training: bool)

Auto Trait Implementations§

§

impl<L, O> Freeze for LSTMBatchTrainer<L, O>
where L: Freeze, O: Freeze,

§

impl<L, O> RefUnwindSafe for LSTMBatchTrainer<L, O>

§

impl<L, O> Send for LSTMBatchTrainer<L, O>
where L: Send, O: Send,

§

impl<L, O> Sync for LSTMBatchTrainer<L, O>
where L: Sync, O: Sync,

§

impl<L, O> Unpin for LSTMBatchTrainer<L, O>
where L: Unpin, O: Unpin,

§

impl<L, O> UnwindSafe for LSTMBatchTrainer<L, O>
where L: UnwindSafe, O: UnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V