Layer

Trait Layer 

Source
pub trait Layer<F: Float + Debug + ScalarOperand>: Send + Sync {
    // Required methods
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>>;
    fn update(&mut self, learningrate: F) -> Result<()>;
    fn as_any(&self) -> &dyn Any;
    fn as_any_mut(&mut self) -> &mut dyn Any;

    // Provided methods
    fn params(&self) -> Vec<Array<F, IxDyn>> { ... }
    fn gradients(&self) -> Vec<Array<F, IxDyn>> { ... }
    fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()> { ... }
    fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()> { ... }
    fn set_training(&mut self, _training: bool) { ... }
    fn is_training(&self) -> bool { ... }
    fn layer_type(&self) -> &str { ... }
    fn parameter_count(&self) -> usize { ... }
    fn layer_description(&self) -> String { ... }
    fn inputshape(&self) -> Option<Vec<usize>> { ... }
    fn outputshape(&self) -> Option<Vec<usize>> { ... }
    fn name(&self) -> Option<&str> { ... }
}
Expand description

Base trait for neural network layers

This trait defines the core interface that all neural network layers must implement. It supports forward propagation, backpropagation, parameter management, and training/evaluation mode switching.

Required Methods§

Source

fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Forward pass of the layer

Computes the output of the layer given an input tensor.

Source

fn backward(
    &self,
    input: &Array<F, IxDyn>,
    grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>>

Backward pass of the layer to compute gradients

Computes gradients with respect to the layer’s input, which is needed for backpropagation.

Source

fn update(&mut self, learningrate: F) -> Result<()>

Update the layer parameters with the given learning rate

Source

fn as_any(&self) -> &dyn Any

Get the layer as a dyn Any for downcasting

Source

fn as_any_mut(&mut self) -> &mut dyn Any

Get the layer as a mutable dyn Any for downcasting

Provided Methods§

Source

fn params(&self) -> Vec<Array<F, IxDyn>>

Get the parameters of the layer

Source

fn gradients(&self) -> Vec<Array<F, IxDyn>>

Get the gradients of the layer parameters

Source

fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()>

Set the gradients of the layer parameters

Source

fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()>

Set the parameters of the layer

Source

fn set_training(&mut self, _training: bool)

Set the layer to training mode (true) or evaluation mode (false)

Source

fn is_training(&self) -> bool

Get the current training mode

Source

fn layer_type(&self) -> &str

Get the type of the layer (e.g., “Dense”, “Conv2D”)

Source

fn parameter_count(&self) -> usize

Get the number of trainable parameters in this layer

Source

fn layer_description(&self) -> String

Get a detailed description of this layer

Source

fn inputshape(&self) -> Option<Vec<usize>>

Get the input shape if known

Source

fn outputshape(&self) -> Option<Vec<usize>>

Get the output shape if known

Source

fn name(&self) -> Option<&str>

Get the name of the layer if set

Implementors§

Source§

impl<F: Float + Debug + Send + Sync + ScalarOperand + Default> Layer<F> for Conv2D<F>

Source§

impl<F: Float + Debug + Send + Sync + ScalarOperand + Default> Layer<F> for MaxPool2D<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dense<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dropout<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for BatchNorm<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for LayerNorm<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Bidirectional<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for GRU<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for LSTM<F>

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for RNN<F>

Source§

impl<F: Float + Debug + ScalarOperand> Layer<F> for Sequential<F>