pub struct LSTMNetwork {
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub is_training: bool,
    /* private fields */
}
Multi-layer LSTM network for sequence modeling with dropout support.

Stacks multiple LSTM cells so that the output of layer i becomes the input to layer i+1. Supports both inference and training with configurable dropout regularization.
Fields

input_size: usize
hidden_size: usize
num_layers: usize
is_training: bool
Implementations

impl LSTMNetwork

pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self
Creates a new multi-layer LSTM network. The first layer accepts input_size dimensions; subsequent layers accept hidden_size dimensions from the previous layer.
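A minimal construction sketch; the crate path lstm is a placeholder, since this page does not name the crate:

    use lstm::LSTMNetwork; // hypothetical crate path

    // A 3-layer network: 16-dimensional inputs, 32-dimensional
    // hidden state per layer.
    let net = LSTMNetwork::new(16, 32, 3);
    assert_eq!(net.num_layers, 3);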
pub fn with_input_dropout(self, dropout_rate: f64, variational: bool) -> Self
pub fn with_recurrent_dropout(self, dropout_rate: f64, variational: bool) -> Self
pub fn with_output_dropout(self, dropout_rate: f64) -> Self
pub fn with_zoneout(self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64) -> Self
pub fn with_layer_dropout(self, layer_configs: Vec<LayerDropoutConfig>) -> Self

Builder-style configuration methods; each consumes the network and returns it, so calls can be chained (see the sketch after train/eval below).
pub fn train(&mut self)
pub fn eval(&mut self)

Switch the network between training mode (dropout active) and evaluation mode.
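A configuration sketch, assuming the chaining semantics implied by the self -> Self signatures; the rates are arbitrary, and the variational flag is assumed to reuse one mask across time steps:

    use lstm::LSTMNetwork; // hypothetical crate path

    let mut net = LSTMNetwork::new(16, 32, 2)
        .with_input_dropout(0.2, true)     // variational input dropout (assumed semantics)
        .with_recurrent_dropout(0.1, true)
        .with_output_dropout(0.3);

    net.train(); // dropout active
    // ... run training ...
    net.eval();  // dropout disabled for inference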
pub fn from_cells(
    cells: Vec<LSTMCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
) -> Self
Creates a network from existing cells (used for deserialization)
pub fn get_cells_mut(&mut self) -> &mut [LSTMCell]
Returns a mutable reference to the cells (for training-mode changes)
pub fn forward(
    &mut self,
    input: &Array2<f64>,
    hx: &Array2<f64>,
    cx: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>)
Forward pass for inference (no caching)
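An inference sketch. This page documents column-oriented batch shapes only for forward_batch below; a single-column layout of (input_size, 1) and (hidden_size, 1) is assumed here:

    use ndarray::Array2;

    let (input_size, hidden_size) = (16, 32);
    let mut net = LSTMNetwork::new(input_size, hidden_size, 2);

    let x  = Array2::<f64>::zeros((input_size, 1));  // assumed shape
    let hx = Array2::<f64>::zeros((hidden_size, 1));
    let cx = Array2::<f64>::zeros((hidden_size, 1));

    let (hy, cy) = net.forward(&x, &hx, &cx);
    assert_eq!(hy.dim(), (hidden_size, 1));
    assert_eq!(cy.dim(), (hidden_size, 1));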
pub fn forward_with_cache(
    &mut self,
    input: &Array2<f64>,
    hx: &Array2<f64>,
    cx: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>, LSTMNetworkCache)
Forward pass with caching for training
pub fn backward(
    &self,
    dhy: &Array2<f64>,
    dcy: &Array2<f64>,
    cache: &LSTMNetworkCache,
) -> (Vec<LSTMCellGradients>, Array2<f64>)
Backward pass through all layers (reverse order)
Implements backpropagation through the multi-layer stack. Returns gradients for each layer and input gradients.
pub fn update_parameters<O: Optimizer>(
    &mut self,
    gradients: &[LSTMCellGradients],
    optimizer: &mut O,
)
Update parameters for all layers using computed gradients
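A single training-step sketch combining forward_with_cache, backward, and update_parameters. The Sgd type is a placeholder for any Optimizer implementation, and the loss gradients are dummies:

    use ndarray::Array2;

    let (input_size, hidden_size) = (16, 32);
    let mut net = LSTMNetwork::new(input_size, hidden_size, 2);
    net.train();

    let x  = Array2::<f64>::zeros((input_size, 1));
    let hx = Array2::<f64>::zeros((hidden_size, 1));
    let cx = Array2::<f64>::zeros((hidden_size, 1));

    // Forward pass, keeping the cache backward() needs.
    let (hy, cy, cache) = net.forward_with_cache(&x, &hx, &cx);

    // Stand-ins for dL/dh and dL/dc from a real loss.
    let dhy = Array2::<f64>::ones(hy.dim());
    let dcy = Array2::<f64>::zeros(cy.dim());

    let (grads, _dinput) = net.backward(&dhy, &dcy, &cache);

    let mut opt = Sgd::new(0.01); // hypothetical Optimizer impl
    net.update_parameters(&grads, &mut opt);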
pub fn zero_gradients(&self) -> Vec<LSTMCellGradients>
Initialize zero gradients for all layers
pub fn forward_sequence_with_cache(
    &mut self,
    sequence: &[Array2<f64>],
) -> (Vec<(Array2<f64>, Array2<f64>)>, Vec<LSTMNetworkCache>)
Process an entire sequence with caching for training
Maintains hidden/cell state across time steps within the sequence. Returns outputs and caches for each time step.
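A sequence sketch under the same assumed single-column layout; each returned cache can later drive backward() when unrolling backpropagation through time:

    use ndarray::Array2;

    let (input_size, hidden_size) = (16, 32);
    let mut net = LSTMNetwork::new(input_size, hidden_size, 2);

    // Ten time steps, one column per step (assumed layout).
    let sequence: Vec<Array2<f64>> =
        (0..10).map(|_| Array2::zeros((input_size, 1))).collect();

    let (outputs, caches) = net.forward_sequence_with_cache(&sequence);
    assert_eq!(outputs.len(), sequence.len());
    assert_eq!(caches.len(), sequence.len());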
pub fn forward_batch_sequences(
    &mut self,
    batch_sequences: &[Vec<Array2<f64>>],
) -> Vec<Vec<(Array2<f64>, Array2<f64>)>>

Processes several sequences as a batch, returning the per-time-step (hidden, cell) outputs for each sequence.
pub fn forward_batch(
    &mut self,
    batch_input: &Array2<f64>,
    batch_hx: &Array2<f64>,
    batch_cx: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>)
Batch forward pass for single time step across multiple sequences
Arguments

batch_input - Input tensor of shape (input_size, batch_size)
batch_hx - Hidden-state tensor of shape (hidden_size, batch_size)
batch_cx - Cell-state tensor of shape (hidden_size, batch_size)

Returns

Tuple of (new_hidden_states, new_cell_states) with the same batch dimensions
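A batch sketch using the shapes documented above:

    use ndarray::Array2;

    let (input_size, hidden_size, batch_size) = (16, 32, 8);
    let mut net = LSTMNetwork::new(input_size, hidden_size, 2);

    let x  = Array2::<f64>::zeros((input_size, batch_size));
    let hx = Array2::<f64>::zeros((hidden_size, batch_size));
    let cx = Array2::<f64>::zeros((hidden_size, batch_size));

    let (new_h, new_c) = net.forward_batch(&x, &hx, &cx);
    assert_eq!(new_h.dim(), (hidden_size, batch_size));
    assert_eq!(new_c.dim(), (hidden_size, batch_size));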
pub fn forward_batch_with_cache(
    &mut self,
    batch_input: &Array2<f64>,
    batch_hx: &Array2<f64>,
    batch_cx: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>, LSTMNetworkBatchCache)
Batch forward pass with caching for training
Similar to forward_batch but caches intermediate values needed for backpropagation
pub fn backward_batch(
    &self,
    dhy: &Array2<f64>,
    dcy: &Array2<f64>,
    cache: &LSTMNetworkBatchCache,
) -> (Vec<LSTMCellGradients>, Array2<f64>)
Batch backward pass for training
Computes gradients for an entire batch simultaneously
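A batched training-step sketch mirroring the single-sample flow above (Sgd again stands in for any Optimizer implementation):

    use ndarray::Array2;

    let (input_size, hidden_size, batch_size) = (16, 32, 8);
    let mut net = LSTMNetwork::new(input_size, hidden_size, 2);
    net.train();

    let x  = Array2::<f64>::zeros((input_size, batch_size));
    let hx = Array2::<f64>::zeros((hidden_size, batch_size));
    let cx = Array2::<f64>::zeros((hidden_size, batch_size));

    let (hy, cy, cache) = net.forward_batch_with_cache(&x, &hx, &cx);

    // Dummy loss gradients over the whole batch.
    let dhy = Array2::<f64>::ones(hy.dim());
    let dcy = Array2::<f64>::zeros(cy.dim());

    let (grads, _dx) = net.backward_batch(&dhy, &dcy, &cache);

    let mut opt = Sgd::new(0.01); // hypothetical Optimizer impl
    net.update_parameters(&grads, &mut opt);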
Trait Implementations

impl Clone for LSTMNetwork

fn clone(&self) -> LSTMNetwork
fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.