/// LSTM cell with trainable parameters and dropout support.
///
/// NOTE(review): the per-field descriptions below are inferred from field
/// names and types only; the implementation is not visible here, so confirm
/// shapes and gate layout against `new` / `forward` before relying on them.
pub struct LSTMCell {
/// Presumably the input-to-hidden weight matrix — TODO confirm shape.
pub w_ih: Array2<f64>,
/// Presumably the hidden-to-hidden (recurrent) weight matrix — TODO confirm shape.
pub w_hh: Array2<f64>,
/// Presumably the bias paired with `w_ih`; stored as a 2-D array.
pub b_ih: Array2<f64>,
/// Presumably the bias paired with `w_hh`; stored as a 2-D array.
pub b_hh: Array2<f64>,
/// Hidden size; presumably the `hidden_size` value passed to `new` — verify.
pub hidden_size: usize,
/// Optional dropout configuration, presumably applied to the cell input;
/// `None` means no such dropout is configured.
pub input_dropout: Option<Dropout>,
/// Optional dropout configuration, presumably applied on the recurrent path;
/// `None` means no such dropout is configured.
pub recurrent_dropout: Option<Dropout>,
/// Optional dropout configuration, presumably applied to the cell output;
/// `None` means no such dropout is configured.
pub output_dropout: Option<Dropout>,
/// Optional zoneout configuration; `None` means zoneout is not configured.
pub zoneout: Option<Zoneout>,
/// Training-mode flag; presumably toggled by `train()` / `eval()` — confirm.
pub is_training: bool,
}
Expand description
LSTM cell with trainable parameters and dropout support
Fields§
§w_ih: Array2<f64>
§w_hh: Array2<f64>
§b_ih: Array2<f64>
§b_hh: Array2<f64>
§hidden_size: usize
§input_dropout: Option<Dropout>
§recurrent_dropout: Option<Dropout>
§output_dropout: Option<Dropout>
§zoneout: Option<Zoneout>
§is_training: bool
Implementations§
Source§impl LSTMCell
impl LSTMCell
Source pub fn new(input_size: usize, hidden_size: usize) -> Self
pub fn new(input_size: usize, hidden_size: usize) -> Self
Creates a new LSTM cell with Xavier-uniform weight initialization.
pub fn with_input_dropout(self, dropout_rate: f64, variational: bool) -> Self
pub fn with_recurrent_dropout( self, dropout_rate: f64, variational: bool, ) -> Self
pub fn with_output_dropout(self, dropout_rate: f64) -> Self
pub fn with_zoneout( self, cell_zoneout_rate: f64, hidden_zoneout_rate: f64, ) -> Self
pub fn train(&mut self)
pub fn eval(&mut self)
pub fn forward( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>)
pub fn forward_with_cache( &mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>, ) -> (Array2<f64>, Array2<f64>, LSTMCellCache)
Source pub fn backward(
&self,
dhy: &Array2<f64>,
dcy: &Array2<f64>,
cache: &LSTMCellCache,
) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>)
pub fn backward( &self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellCache, ) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>)
Backward pass implementing LSTM gradient computation with dropout
Returns (parameter_gradients, input_gradient, hidden_gradient, cell_gradient)
Source pub fn zero_gradients(&self) -> LSTMCellGradients
pub fn zero_gradients(&self) -> LSTMCellGradients
Initialize zero gradients for accumulation
Source pub fn update_parameters<O: Optimizer>(
&mut self,
gradients: &LSTMCellGradients,
optimizer: &mut O,
prefix: &str,
)
pub fn update_parameters<O: Optimizer>( &mut self, gradients: &LSTMCellGradients, optimizer: &mut O, prefix: &str, )
Apply gradients using the provided optimizer
Trait Implementations§
Auto Trait Implementations§
impl Freeze for LSTMCell
impl RefUnwindSafe for LSTMCell
impl Send for LSTMCell
impl Sync for LSTMCell
impl Unpin for LSTMCell
impl UnwindSafe for LSTMCell
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more