pub struct TransformerEncoderLayer {
    pub self_attn: MultiheadAttention,
    pub linear1: Linear,
    pub dropout: f32,
    pub linear2: Linear,
    pub norm1: LayerNorm,
    pub norm2: LayerNorm,
    pub activation: fn(Tensor) -> Tensor,
    pub norm_first: bool,
    pub batch_first: bool,
}
A single Transformer Encoder layer, analogous to torch.nn.TransformerEncoderLayer.
This layer implements a standard Transformer encoder block with a multi-head self-attention mechanism followed by a position-wise feedforward network. Layer normalization can be applied either before (“pre-norm”) or after (“post-norm”) the attention and feedforward sub-layers.
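The difference between the two orderings can be sketched with scalar stand-ins. The functions below are toy placeholders, not the layer's actual attention or normalization; they only illustrate where the residual connection and the norm sit relative to each sub-layer in the two modes.

```rust
// Placeholders standing in for LayerNorm and for a sub-layer
// (self-attention or feedforward). Illustrative only.
fn norm(x: f32) -> f32 {
    x / 2.0
}
fn sublayer(x: f32) -> f32 {
    x + 1.0
}

/// Post-norm (norm_first = false): x -> norm(x + sublayer(x))
fn post_norm_step(x: f32) -> f32 {
    norm(x + sublayer(x))
}

/// Pre-norm (norm_first = true): x -> x + sublayer(norm(x))
fn pre_norm_step(x: f32) -> f32 {
    x + sublayer(norm(x))
}

fn main() {
    // Same input, different ordering of the residual and the norm.
    println!("{}", post_norm_step(2.0)); // norm(2 + 3) = 2.5
    println!("{}", pre_norm_step(2.0));  // 2 + sublayer(1) = 4
}
```

Note that in pre-norm mode the residual path bypasses the normalization entirely, which is why pre-norm stacks are often easier to train deep.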
Fields
self_attn: MultiheadAttention - The multi-head self-attention module.
linear1: Linear - The first linear layer of the feedforward network (expansion).
dropout: f32 - Dropout probability applied after the attention and feedforward layers.
linear2: Linear - The second linear layer of the feedforward network (projection back to d_model).
norm1: LayerNorm - LayerNorm applied after the self-attention block (or before, if norm_first is true).
norm2: LayerNorm - LayerNorm applied after the feedforward block (or before, if norm_first is true).
activation: fn(Tensor) -> Tensor - The activation function used in the feedforward network (e.g., ReLU, GELU).
norm_first: bool - If true, applies layer normalization before each sub-layer (pre-norm).
batch_first: bool - If true, expects input tensors of shape (batch_size, seq_len, d_model).
Implementations
impl TransformerEncoderLayer
pub fn new(
    d_model: u64,
    nhead: u64,
    dim_feedforward: u64,
    dropout: f32,
    activation: fn(Tensor) -> Tensor,
    layer_norm_eps: f64,
    batch_first: bool,
    norm_first: bool,
    bias: bool,
    dtype: DType,
) -> Result<Self, ZyxError>
Constructs a new TransformerEncoderLayer.
Arguments
d_model - The number of expected features in the input (embedding size).
nhead - The number of attention heads.
dim_feedforward - The dimension of the feedforward network.
dropout - Dropout probability applied after the attention and feedforward layers.
activation - Activation function used in the feedforward network.
layer_norm_eps - Epsilon value for numerical stability in layer normalization.
batch_first - If true, input/output tensors are expected in (batch, seq, feature) format.
norm_first - If true, applies layer normalization before sub-layers (pre-norm).
bias - If true, linear layers include bias terms.
dtype - The data type of the layer’s parameters and outputs.
Returns
A Result containing the initialized TransformerEncoderLayer or a ZyxError.
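To get a feel for how the arguments size the layer, the sketch below counts learnable parameters under the conventional transformer layout (fused Q/K/V in-projection plus output projection for attention, two FFN linears, two LayerNorms with scale and shift). This is an assumption about the internal layout, not something this page specifies; the exact count for this crate may differ.

```rust
/// Parameter count for one encoder layer under the conventional layout.
/// `nhead` only partitions the projections and does not change the count,
/// so it is omitted here.
fn encoder_layer_params(d_model: u64, dim_feedforward: u64, bias: bool) -> u64 {
    let b = if bias { 1 } else { 0 };
    // self_attn: Q, K, V projections plus output projection.
    let attn = 3 * d_model * d_model + 3 * b * d_model
        + d_model * d_model + b * d_model;
    // linear1 expands d_model -> dim_feedforward; linear2 projects back.
    let ffn = d_model * dim_feedforward + b * dim_feedforward
        + dim_feedforward * d_model + b * d_model;
    // norm1 and norm2: gamma and beta each, of length d_model.
    let norms = 2 * 2 * d_model;
    attn + ffn + norms
}

fn main() {
    // The classic "base" configuration: d_model = 512, dim_feedforward = 2048.
    let total = encoder_layer_params(512, 2048, true);
    println!("{total}"); // 3152384
}
```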
pub fn forward(
    &self,
    src: impl Into<Tensor>,
    src_mask: Option<Tensor>,
    src_key_padding_mask: Option<Tensor>,
) -> Result<Tensor, ZyxError>
Performs a forward pass of the Transformer encoder layer.
Arguments
src - Input tensor of shape (seq_len, batch_size, d_model), or (batch_size, seq_len, d_model) if batch_first is true.
src_mask - Optional attention mask tensor to prevent attention to certain positions.
src_key_padding_mask - Optional mask tensor for padding positions in the input.
Returns
A Result containing the output tensor after applying self-attention and feedforward blocks.
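A common use of src_mask is a causal (look-ahead) mask. The sketch below builds one as a plain boolean matrix, assuming the widespread convention that a masked-out entry marks a position that must not be attended to; this page does not state which convention (boolean vs. additive -inf) this crate's forward expects, so check before relying on it.

```rust
/// Build a causal attention mask for `seq_len` positions.
/// Convention assumed here: `true` marks a key position that the
/// query at that row must NOT attend to.
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|q| (0..seq_len).map(|k| k > q).collect()) // block keys after the query
        .collect()
}

fn main() {
    let mask = causal_mask(3);
    for row in &mask {
        println!("{row:?}");
    }
    // Row 0 may only attend to position 0; row 2 sees all positions up to itself.
}
```

src_key_padding_mask serves a different purpose: rather than restricting positions per query, it excludes entire padded key positions for every query in a batch element.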
Trait Implementations
impl Debug for TransformerEncoderLayer
impl<'a> IntoIterator for &'a TransformerEncoderLayer
impl<'a> IntoIterator for &'a mut TransformerEncoderLayer
impl Module for TransformerEncoderLayer
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Tensor>
Iterate over all tensors mutably.
fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a Tensor)>
Iterate over tensors without consuming the module.
fn iter_tensors_mut<'a>(
    &'a mut self,
) -> impl Iterator<Item = (String, &'a mut Tensor)>
Iterate over named tensors mutably.
Auto Trait Implementations
impl Freeze for TransformerEncoderLayer
impl RefUnwindSafe for TransformerEncoderLayer
impl Send for TransformerEncoderLayer
impl Sync for TransformerEncoderLayer
impl Unpin for TransformerEncoderLayer
impl UnsafeUnpin for TransformerEncoderLayer
impl UnwindSafe for TransformerEncoderLayer
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.