
TransformerEncoderLayer

Struct TransformerEncoderLayer 

pub struct TransformerEncoderLayer {
    pub self_attn: MultiheadAttention,
    pub linear1: Linear,
    pub dropout: f32,
    pub linear2: Linear,
    pub norm1: LayerNorm,
    pub norm2: LayerNorm,
    pub activation: fn(Tensor) -> Tensor,
    pub norm_first: bool,
    pub batch_first: bool,
}

A single Transformer Encoder layer, analogous to torch.nn.TransformerEncoderLayer.

This layer implements a standard Transformer encoder block with a multi-head self-attention mechanism followed by a position-wise feedforward network. Layer normalization can be applied either before (“pre-norm”) or after (“post-norm”) the attention and feedforward sub-layers.
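The two orderings differ only in where the normalization sits relative to the residual connection. Below is a minimal sketch on plain f32 vectors rather than zyx tensors. It assumes the standard PyTorch-style wiring, in which each sub-layer is wrapped in a residual connection, and it uses a LayerNorm without learned scale or shift. The closures standing in for attention and the feedforward network are hypothetical placeholders, and dropout is omitted.

```rust
/// LayerNorm over one feature vector, without learned scale or shift.
fn layer_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect()
}

/// Element-wise sum, used for the residual connection.
fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// One residual sub-layer; `sublayer` stands in for self-attention or the feedforward network.
fn residual_block(
    x: &[f32],
    sublayer: impl Fn(&[f32]) -> Vec<f32>,
    norm_first: bool,
    eps: f32,
) -> Vec<f32> {
    if norm_first {
        // Pre-norm: x + sublayer(norm(x))
        let normed = layer_norm(x, eps);
        add(x, &sublayer(normed.as_slice()))
    } else {
        // Post-norm: norm(x + sublayer(x))
        let summed = add(x, &sublayer(x));
        layer_norm(&summed, eps)
    }
}

/// Attention sub-layer first, then the feedforward sub-layer, in the chosen ordering.
fn encoder_layer(
    x: &[f32],
    attn: impl Fn(&[f32]) -> Vec<f32>,
    ff: impl Fn(&[f32]) -> Vec<f32>,
    norm_first: bool,
) -> Vec<f32> {
    let eps = 1e-5;
    let h = residual_block(x, attn, norm_first, eps);
    residual_block(&h, ff, norm_first, eps)
}
```

With sub-layers that output zeros, the pre-norm path returns the input unchanged, while the post-norm path always returns a normalized vector. This is one reason pre-norm stacks keep an unnormalized residual stream.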

Fields

self_attn: MultiheadAttention
  • The multi-head self-attention module.
linear1: Linear
  • The first linear layer of the feedforward network (expansion from d_model to dim_feedforward features).
dropout: f32
  • Dropout probability applied after the attention and feedforward sub-layers.
linear2: Linear
  • The second linear layer of the feedforward network (projection from dim_feedforward back to d_model).
norm1: LayerNorm
  • LayerNorm for the self-attention sub-layer: applied after it (post-norm), or before it when norm_first is true (pre-norm).
norm2: LayerNorm
  • LayerNorm for the feedforward sub-layer: applied after it (post-norm), or before it when norm_first is true (pre-norm).
activation: fn(Tensor) -> Tensor
  • The activation function applied between linear1 and linear2 in the feedforward network (e.g., ReLU or GELU).
norm_first: bool
  • If true, applies layer normalization before each sub-layer (pre-norm); otherwise after it (post-norm).
batch_first: bool
  • If true, expects input tensors of shape (batch_size, seq_len, d_model); otherwise (seq_len, batch_size, d_model).
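The feedforward fields compose as linear2(activation(linear1(x))), applied independently at each sequence position. Here is a minimal sketch for one position, using a hypothetical ToyLinear struct on Vec<f32> in place of zyx's Linear and Tensor. The activation is passed as a plain fn pointer, mirroring the fn(Tensor) -> Tensor field, and dropout is left out, as it would be at inference time.

```rust
/// Hypothetical dense layer: y = W x + b, with one weight row per output feature.
struct ToyLinear {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl ToyLinear {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.weight
            .iter()
            .zip(&self.bias)
            // Dot product of each weight row with x, plus that row's bias.
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// Position-wise feedforward: expand with linear1, activate, project back with linear2.
fn feedforward(x: &[f32], linear1: &ToyLinear, linear2: &ToyLinear, activation: fn(f32) -> f32) -> Vec<f32> {
    let hidden: Vec<f32> = linear1.forward(x).into_iter().map(activation).collect();
    linear2.forward(&hidden)
}
```

For d_model = 2 and dim_feedforward = 3, linear1 has three weight rows of length 2 and linear2 has two rows of length 3, so the output has the same width as the input.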

Implementations


impl TransformerEncoderLayer


pub fn new(
    d_model: u64,
    nhead: u64,
    dim_feedforward: u64,
    dropout: f32,
    activation: fn(Tensor) -> Tensor,
    layer_norm_eps: f64,
    batch_first: bool,
    norm_first: bool,
    bias: bool,
    dtype: DType,
) -> Result<Self, ZyxError>

Constructs a new TransformerEncoderLayer.

Arguments
  • d_model - The number of expected features in the input (embedding size).
  • nhead - The number of attention heads.
  • dim_feedforward - The hidden dimension of the feedforward network (the output size of linear1).
  • dropout - Dropout probability applied after attention and feedforward layers.
  • activation - Activation function used in the feedforward network.
  • layer_norm_eps - Epsilon value for numerical stability in layer normalization.
  • batch_first - If true, input/output tensors are expected in (batch, seq, feature) format.
  • norm_first - If true, applies layer normalization before sub-layers (pre-norm).
  • bias - If true, linear layers include bias terms.
  • dtype - The data type of the layer’s parameters and outputs.
Returns

A Result containing the initialized TransformerEncoderLayer or a ZyxError.
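In standard multi-head attention, d_model is split evenly across nhead heads, so d_model should be divisible by nhead. Whether new() itself rejects an indivisible d_model is not stated above. The sketch below uses a hypothetical EncoderLayerConfig struct (not a zyx type) to compute the per-head dimension and the parameter count of the two feedforward Linear layers. It assumes each Linear stores an in_features × out_features weight plus, when bias is true, an out_features bias.

```rust
/// Hypothetical mirror of some `TransformerEncoderLayer::new` arguments (not a zyx type).
struct EncoderLayerConfig {
    d_model: u64,
    nhead: u64,
    dim_feedforward: u64,
}

impl EncoderLayerConfig {
    /// Per-head dimension, assuming d_model is split evenly across the heads.
    fn head_dim(&self) -> Result<u64, String> {
        if self.nhead == 0 || self.d_model % self.nhead != 0 {
            return Err(format!(
                "d_model ({}) must be divisible by nhead ({})",
                self.d_model, self.nhead
            ));
        }
        Ok(self.d_model / self.nhead)
    }

    /// Parameters in linear1 (d_model -> dim_feedforward) and linear2
    /// (dim_feedforward -> d_model); `bias` mirrors the constructor flag.
    fn feedforward_params(&self, bias: bool) -> u64 {
        let weights = 2 * self.d_model * self.dim_feedforward;
        let biases = if bias { self.dim_feedforward + self.d_model } else { 0 };
        weights + biases
    }
}
```

For d_model = 512, nhead = 8 and dim_feedforward = 2048, each head works on 64 features, and the feedforward network holds 2,099,712 parameters with bias or 2,097,152 without.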


pub fn forward(
    &self,
    src: impl Into<Tensor>,
    src_mask: Option<Tensor>,
    src_key_padding_mask: Option<Tensor>,
) -> Result<Tensor, ZyxError>

Performs a forward pass of the Transformer encoder layer.

Arguments
  • src - Input tensor of shape (seq_len, batch_size, d_model), or (batch_size, seq_len, d_model) if batch_first is true.
  • src_mask - Optional attention mask tensor to prevent attention to certain positions.
  • src_key_padding_mask - Optional mask tensor for padding positions in the input.
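The two src layouts differ only in which of the first two axes is outermost. As a sketch, assuming a flat, row-major Vec<f32> buffer (zyx tensors may store data differently, and in practice a tensor transpose would do this), converting seq-first data to the batch-first layout looks like this:

```rust
/// Reorders a flat, row-major buffer from (seq_len, batch, d_model)
/// to (batch, seq_len, d_model), the layout expected when batch_first is true.
fn seq_first_to_batch_first(x: &[f32], seq_len: usize, batch: usize, d_model: usize) -> Vec<f32> {
    assert_eq!(x.len(), seq_len * batch * d_model, "buffer size does not match shape");
    let mut out = vec![0.0; x.len()];
    for s in 0..seq_len {
        for b in 0..batch {
            // Start of the d_model-long feature vector in each layout.
            let src = (s * batch + b) * d_model;
            let dst = (b * seq_len + s) * d_model;
            out[dst..dst + d_model].copy_from_slice(&x[src..src + d_model]);
        }
    }
    out
}
```

Because the operation just swaps the first two axes, calling it again with seq_len and batch exchanged converts batch-first data back to seq-first.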
Returns

A Result containing the output tensor after applying the self-attention and feedforward blocks. The output has the same shape as src.
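Both masks restrict which key positions each query may attend to. The sketch below applies them to one row of attention scores before the softmax, on plain slices rather than tensors. It assumes the PyTorch boolean convention, where true means "masked out". The convention zyx uses for src_mask and src_key_padding_mask is not stated here, so check it before relying on this. The helper names are hypothetical.

```rust
/// Causal mask row for query position `i`: keys after `i` are masked out (`true`).
fn causal_mask_row(i: usize, seq_len: usize) -> Vec<bool> {
    (0..seq_len).map(|j| j > i).collect()
}

/// Softmax over one row of attention scores, giving masked keys zero weight.
/// Assumes `true` means "do not attend" in both masks (PyTorch convention).
fn masked_softmax(scores: &[f32], attn_mask_row: &[bool], key_padding: &[bool]) -> Vec<f32> {
    // A key is attendable only if neither mask excludes it.
    let allowed: Vec<bool> = attn_mask_row
        .iter()
        .zip(key_padding)
        .map(|(m, p)| !*m && !*p)
        .collect();
    // Subtract the largest allowed score for numerical stability.
    let max = scores
        .iter()
        .zip(&allowed)
        .filter(|(_, a)| **a)
        .map(|(s, _)| *s)
        .fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores
        .iter()
        .zip(&allowed)
        .map(|(s, a)| if *a { (*s - max).exp() } else { 0.0 })
        .collect();
    // If every key is masked, the sum is 0 and the result is NaN.
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

A causal src_mask is the same for every sequence in the batch, while the padding mask varies per sequence. That is why they are passed as separate arguments.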

Trait Implementations


impl Debug for TransformerEncoderLayer


fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<'a> IntoIterator for &'a TransformerEncoderLayer


type Item = &'a Tensor

The type of the elements being iterated over.

type IntoIter = IntoIter<&'a Tensor>

Which kind of iterator are we turning this into?

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value.

impl<'a> IntoIterator for &'a mut TransformerEncoderLayer


type Item = &'a mut Tensor

The type of the elements being iterated over.

type IntoIter = IntoIter<&'a mut Tensor>

Which kind of iterator are we turning this into?

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value.
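The two IntoIterator impls let a layer be used directly in a for loop over its tensors: for t in &layer to read them, or for t in &mut layer to update them in place. Here is a self-contained sketch of the same pattern on a hypothetical ToyLayer, with Vec<f32> standing in for Tensor. Which tensors the real impls yield, and in what order, is not specified here.

```rust
/// Hypothetical module with two parameter buffers.
struct ToyLayer {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl<'a> IntoIterator for &'a ToyLayer {
    type Item = &'a Vec<f32>;
    type IntoIter = std::vec::IntoIter<&'a Vec<f32>>;

    fn into_iter(self) -> Self::IntoIter {
        // Collect shared references to every parameter buffer, in field order.
        vec![&self.weight, &self.bias].into_iter()
    }
}

impl<'a> IntoIterator for &'a mut ToyLayer {
    type Item = &'a mut Vec<f32>;
    type IntoIter = std::vec::IntoIter<&'a mut Vec<f32>>;

    fn into_iter(self) -> Self::IntoIter {
        // Borrowing disjoint fields mutably at the same time is allowed.
        vec![&mut self.weight, &mut self.bias].into_iter()
    }
}
```

Collecting the references into a Vec and returning its IntoIter is what gives the Item and IntoIter types their shape, matching IntoIter<&'a Tensor> above.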

impl Module for TransformerEncoderLayer


fn iter<'a>(&'a self) -> impl Iterator<Item = &'a Tensor>

Iterate over all tensors immutably

fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Tensor>

Iterate over all tensors mutably

fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a Tensor)>

Iterate over all tensors together with their names, without consuming the module

fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut Tensor)>

Iterate mutably over all tensors together with their names

fn realize(&self) -> Result<(), ZyxError>

Realize all tensors in the module.

fn set_params(&mut self, params: &mut HashMap<String, Tensor>)

Set parameters from params, removing each entry that is used from the map. Parameters whose names are not found in params are skipped.

fn save(&self, path: impl AsRef<Path>) -> Result<(), ZyxError>

Save tensors or modules to a file whose format is determined by the file extension. Currently safetensors is the only supported format.
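Because set_params takes matching entries out of the map and leaves the rest, whatever remains afterwards is exactly the set of unused entries. Below is a minimal sketch of that contract on a hypothetical NamedParams struct, with Vec<f32> in place of Tensor. The parameter names are illustrative and are not zyx's actual naming scheme.

```rust
use std::collections::HashMap;

/// Hypothetical module storing named parameters side by side.
struct NamedParams {
    names: Vec<String>,
    values: Vec<Vec<f32>>,
}

impl NamedParams {
    /// Takes each matching entry out of `params`; parameters without a
    /// matching key keep their current value.
    fn set_params(&mut self, params: &mut HashMap<String, Vec<f32>>) {
        for (name, value) in self.names.iter().zip(self.values.iter_mut()) {
            if let Some(new_value) = params.remove(name) {
                *value = new_value;
            }
        }
    }
}
```

Checking that the map is empty after the call is a cheap way to catch names that did not match any parameter, for example after loading a checkpoint saved from a differently structured model.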


Blanket Implementations


impl<T> Any for T
where T: 'static + ?Sized,


fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,


fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,


fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T


fn from(t: T) -> T

Returns the argument unchanged.


impl<T, U> Into<U> for T
where U: From<T>,


fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.


impl<T, U> TryFrom<U> for T
where U: Into<T>,


type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,


type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.