/// LSTM cell with trainable parameters and dropout support.
///
/// Construct one with [`LSTMCell::new`] (Xavier-uniform weight initialization) and
/// optionally chain the `with_*` builders to enable input, recurrent, or output
/// dropout and zoneout.
pub struct LSTMCell {
    /// Input-to-hidden weight matrix.
    /// NOTE(review): exact shape/gate stacking is not visible here — confirm in `new`.
    pub w_ih: Array2<f64>,
    /// Hidden-to-hidden (recurrent) weight matrix.
    /// NOTE(review): exact shape/gate stacking is not visible here — confirm in `new`.
    pub w_hh: Array2<f64>,
    /// Bias applied alongside `w_ih` (presumably; confirm against `forward`).
    pub b_ih: Array2<f64>,
    /// Bias applied alongside `w_hh` (presumably; confirm against `forward`).
    pub b_hh: Array2<f64>,
    /// Number of hidden units; presumably the `hidden_size` argument passed to `new`.
    pub hidden_size: usize,
    /// Optional dropout on the cell input; `None` means disabled
    /// (presumably set by `with_input_dropout`).
    pub input_dropout: Option<Dropout>,
    /// Optional dropout on the recurrent hidden state; `None` means disabled
    /// (presumably set by `with_recurrent_dropout`).
    pub recurrent_dropout: Option<Dropout>,
    /// Optional dropout on the cell output; `None` means disabled
    /// (presumably set by `with_output_dropout`).
    pub output_dropout: Option<Dropout>,
    /// Optional zoneout regularization; `None` means disabled
    /// (presumably set by `with_zoneout`).
    pub zoneout: Option<Zoneout>,
    /// Training-mode flag; presumably toggled by `train` / `eval` so that the
    /// dropout and zoneout settings apply only during training — confirm.
    pub is_training: bool,
}
LSTM cell with trainable parameters and dropout support.
Fields
- w_ih: Array2<f64>
- w_hh: Array2<f64>
- b_ih: Array2<f64>
- b_hh: Array2<f64>
- hidden_size: usize
- input_dropout: Option<Dropout>
- recurrent_dropout: Option<Dropout>
- output_dropout: Option<Dropout>
- zoneout: Option<Zoneout>
- is_training: bool
Implementations
impl LSTMCell
pub fn new(input_size: usize, hidden_size: usize) -> Self
Creates a new LSTM cell with Xavier-uniform weight initialization.
pub fn with_input_dropout(self, dropout_rate: f64, variational: bool) -> Self
pub fn with_recurrent_dropout( self, dropout_rate: f64, variational: bool, ) -> Self
pub fn with_output_dropout(self, dropout_rate: f64) -> Self
pub fn with_zoneout( self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64, ) -> Self
pub fn train(&mut self)
pub fn eval(&mut self)
pub fn forward( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>)
pub fn forward_with_cache( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>, LSTMCellCache)
pub fn forward_batch( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>)
Batch forward pass that processes multiple sequences simultaneously.
Arguments
- input: input tensor of shape (input_size, batch_size)
- hx: hidden-state tensor of shape (hidden_size, batch_size)
- cx: cell-state tensor of shape (hidden_size, batch_size)
Returns
- A tuple (new_hidden_state, new_cell_state) with the same batch dimensions as the inputs.
pub fn forward_batch_with_cache( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>, LSTMCellBatchCache)
Batch forward pass with caching, for use during training.
Like forward_batch, but also caches the intermediate values needed for backpropagation.
pub fn backward( &self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellCache, ) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>)
Backward pass implementing LSTM gradient computation, including dropout.
Returns (parameter_gradients, input_gradient, hidden_gradient, cell_gradient).
pub fn backward_batch( &self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellBatchCache, ) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>)
Batch backward pass for training with multiple sequences.
Computes gradients for an entire batch simultaneously.
pub fn zero_gradients(&self) -> LSTMCellGradients
Returns zero-initialized gradients, for use as an accumulator.
pub fn update_parameters<O: Optimizer>( &mut self, gradients: &LSTMCellGradients, optimizer: &mut O, prefix: &str, )
Applies the given gradients using the provided optimizer.
Trait Implementations
Auto Trait Implementations
impl Freeze for LSTMCell
impl RefUnwindSafe for LSTMCell
impl Send for LSTMCell
impl Sync for LSTMCell
impl Unpin for LSTMCell
impl UnwindSafe for LSTMCell
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.