pub struct GRUNetwork {
pub input_size: usize,
pub hidden_size: usize,
pub num_layers: usize,
pub is_training: bool,
/* private fields */
}
Multi-layer GRU network for sequence modeling.
Fields§
§input_size: usize
§hidden_size: usize
§num_layers: usize
§is_training: bool
Implementations§
impl GRUNetwork
pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self
Creates a new multi-layer GRU network.
pub fn with_input_dropout(self, dropout_rate: f64, variational: bool) -> Self
Apply uniform dropout across all layers.
pub fn with_recurrent_dropout(self, dropout_rate: f64, variational: bool) -> Self
pub fn with_output_dropout(self, dropout_rate: f64) -> Self
pub fn with_layer_dropout(self, configs: Vec<LayerDropoutConfig>) -> Self
Apply layer-specific dropout configuration.
pub fn train(&mut self)
pub fn eval(&mut self)
pub fn forward(
    &mut self,
    input: &Array2<f64>,
    hx: &[Array2<f64>],
) -> Vec<Array2<f64>>
Forward pass for a single time step.
pub fn forward_sequence_with_cache(
    &mut self,
    sequence: &[Array2<f64>],
) -> (Vec<(Array2<f64>, Vec<Array2<f64>>)>, Vec<GRUNetworkCache>)
Forward pass for a sequence with caching for training.
pub fn backward(
    &self,
    dhy: &Array2<f64>,
    cache: &GRUNetworkCache,
) -> (Vec<GRUCellGradients>, Array2<f64>)
Backward pass for training.
pub fn update_parameters<O: Optimizer>(
    &mut self,
    gradients: &[GRUCellGradients],
    optimizer: &mut O,
)
Update parameters using the given optimizer.
pub fn zero_gradients(&self) -> Vec<GRUCellGradients>
Initialize zero gradients for all layers.
pub fn get_cells_mut(&mut self) -> &mut [GRUCell]
Get mutable references to cells.
Trait Implementations§
impl Clone for GRUNetwork
fn clone(&self) -> GRUNetwork
Returns a duplicate of the value. Read more
1.0.0 · fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more
Auto Trait Implementations§
impl Freeze for GRUNetwork
impl RefUnwindSafe for GRUNetwork
impl Send for GRUNetwork
impl Sync for GRUNetwork
impl Unpin for GRUNetwork
impl UnwindSafe for GRUNetwork
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more