Struct LSTMNetwork

Source
pub struct LSTMNetwork {
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
    /* private fields */
}
Description

Multi-layer LSTM network for sequence modeling with dropout support

Stacks multiple LSTM cells where the output of layer i becomes the input to layer i+1. Supports both inference and training with configurable dropout regularization.

Fields

  • input_size: usize
  • hidden_size: usize
  • num_layers: usize
  • is_training: bool

Implementations

Source§

impl LSTMNetwork

Source

pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self

Creates a new multi-layer LSTM network

The first layer accepts input of dimension input_size; each subsequent layer accepts the hidden_size-dimensional output of the previous layer.

Source

pub fn with_input_dropout(self, dropout_rate: f64, variational: bool) -> Self

Source

pub fn with_recurrent_dropout( self, dropout_rate: f64, variational: bool, ) -> Self

Source

pub fn with_output_dropout(self, dropout_rate: f64) -> Self

Source

pub fn with_zoneout( self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64, ) -> Self

Source

pub fn with_layer_dropout(self, layer_configs: Vec<LayerDropoutConfig>) -> Self

Source

pub fn train(&mut self)

Source

pub fn eval(&mut self)

Source

pub fn from_cells( cells: Vec<LSTMCell>, input_size: usize, hidden_size: usize, num_layers: usize, ) -> Self

Creates a network from existing cells (used for deserialization)

Source

pub fn get_cells(&self) -> &[LSTMCell]

Gets a reference to the cells (used for serialization).

Source

pub fn get_cells_mut(&mut self) -> &mut [LSTMCell]

Gets a mutable reference to the cells (used for training-mode changes).

Source

pub fn forward( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>)

Forward pass for inference (no caching)

Source

pub fn forward_with_cache( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>, LSTMNetworkCache)

Forward pass with caching for training

Source

pub fn backward( &self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkCache, ) -> (Vec<LSTMCellGradients>, Array2<f64>)

Backward pass through all layers (reverse order)

Implements backpropagation through the multi-layer stack. Returns the gradients for each layer together with the gradient with respect to the network input.

Source

pub fn update_parameters<O: Optimizer>( &mut self, gradients: &[LSTMCellGradients], optimizer: &mut O, )

Update parameters for all layers using computed gradients

Source

pub fn zero_gradients(&self) -> Vec<LSTMCellGradients>

Creates zero-initialized gradients for all layers.

Source

pub fn forward_sequence_with_cache( &mut self, sequence: &[Array2<f64>], ) -> (Vec<(Array2<f64>, Array2<f64>)>, Vec<LSTMNetworkCache>)

Process an entire sequence with caching for training

Maintains hidden/cell state across time steps within the sequence. Returns outputs and caches for each time step.

Source

pub fn forward_batch_sequences( &mut self, batch_sequences: &[Vec<Array2<f64>>], ) -> Vec<Vec<(Array2<f64>, Array2<f64>)>>

Process multiple sequences in a batch

§Arguments
  • batch_sequences - Vector of sequences. Each sequence is a Vec<Array2> of time steps, and each time step is an Array2 of shape (input_size, 1).
§Returns
  • Vector of sequence outputs, where each sequence output is Vec<(Array2, Array2)>
Source

pub fn forward_batch( &mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>)

Batch forward pass for a single time step across multiple sequences.

§Arguments
  • batch_input - Input tensor of shape (input_size, batch_size)
  • batch_hx - Hidden states tensor of shape (hidden_size, batch_size)
  • batch_cx - Cell states tensor of shape (hidden_size, batch_size)
§Returns
  • Tuple of (new_hidden_states, new_cell_states) with same batch dimensions
Source

pub fn forward_batch_with_cache( &mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>, LSTMNetworkBatchCache)

Batch forward pass with caching for training

Similar to `forward_batch`, but also caches the intermediate values needed for backpropagation.

Source

pub fn backward_batch( &self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkBatchCache, ) -> (Vec<LSTMCellGradients>, Array2<f64>)

Batch backward pass for training

Computes gradients for an entire batch simultaneously

Trait Implementations

Source§

impl Clone for LSTMNetwork

Source§

fn clone(&self) -> LSTMNetwork

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl From<&LSTMNetwork> for SerializableLSTMNetwork

Source§

fn from(network: &LSTMNetwork) -> Self

Converts to this type from the input type.
Source§

impl Into<LSTMNetwork> for SerializableLSTMNetwork

Source§

fn into(self) -> LSTMNetwork

Converts this type into the (usually inferred) input type.
Source§

impl PersistentModel for LSTMNetwork

Source§

fn save<P: AsRef<Path>>( &self, path: P, metadata: ModelMetadata, ) -> Result<(), PersistenceError>

Save model to file (format determined by file extension)
Source§

fn load<P: AsRef<Path>>( path: P, ) -> Result<(Self, ModelMetadata), PersistenceError>

Load model from file (format determined by file extension)

Auto Trait Implementations

Blanket Implementations

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V