Struct TransformerDecoderLayer 

pub struct TransformerDecoderLayer { /* private fields */ }

A single layer of a Transformer decoder.

This layer implements the standard Transformer decoder operations:

  1. Self-attention on the target sequence.
  2. Cross-attention using the encoder output (memory).
  3. Feedforward network with activation function.
  4. Residual connections and layer normalization around each of the three sub-layers above.

The layer's behavior can be configured through norm_first (pre-norm vs. post-norm), the dropout rate, and the activation function.
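As a rough illustration of how norm_first changes a single residual sub-layer, the sketch below works on plain Rust vectors rather than zyx Tensors. It assumes the standard pre-norm/post-norm formulation, omits dropout and the learned scale and shift of layer normalization, and the helper names (layer_norm, residual_block) are hypothetical, not part of the zyx API.

```rust
/// Layer normalization of one feature vector without learned scale/shift;
/// `eps` plays the role of `layer_norm_eps`.
fn layer_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + eps).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// Element-wise sum used for the residual connection.
fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Wraps one sub-layer in a residual connection plus layer normalization.
/// Dropout is omitted, which matches inference-time behavior.
fn residual_block(
    x: &[f32],
    sublayer: impl Fn(&[f32]) -> Vec<f32>,
    norm_first: bool,
    eps: f32,
) -> Vec<f32> {
    if norm_first {
        // Pre-norm: x + f(LN(x)); the residual path itself is never normalized.
        add(x, &sublayer(&layer_norm(x, eps)))
    } else {
        // Post-norm: LN(x + f(x)); the block output is always normalized.
        layer_norm(&add(x, &sublayer(x)), eps)
    }
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0];
    // A toy "sub-layer" that doubles its input (captures nothing, so it is Copy).
    let double = |v: &[f32]| v.iter().map(|a| *a * 2.0).collect::<Vec<f32>>();
    println!("pre-norm:  {:?}", residual_block(&x, double, true, 1e-5));
    println!("post-norm: {:?}", residual_block(&x, double, false, 1e-5));
}
```

With pre-norm the output keeps the scale of the raw input on the residual path, while post-norm always returns a zero-mean vector.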

Implementations

impl TransformerDecoderLayer

pub fn new(
    d_model: u64,
    nhead: u64,
    dim_feedforward: u64,
    dropout: f32,
    activation: fn(Tensor) -> Tensor,
    layer_norm_eps: f64,
    batch_first: bool,
    norm_first: bool,
    bias: bool,
    dtype: DType,
) -> Result<Self, ZyxError>

Creates a new TransformerDecoderLayer.

Arguments
  • d_model - Dimensionality of input embeddings (number of features per token).
  • nhead - Number of attention heads in self-attention and cross-attention.
  • dim_feedforward - Hidden dimension of the feedforward network.
  • dropout - Dropout probability applied after attention and feedforward layers.
  • activation - Activation function applied to the hidden layer of the feedforward network (e.g., ReLU).
  • layer_norm_eps - Small epsilon value for numerical stability in layer normalization.
  • batch_first - If true, input tensors have shape [batch, seq, feature]; otherwise they have shape [seq, batch, feature].
  • norm_first - Whether to apply layer normalization before sub-layers (pre-norm) or after (post-norm).
  • bias - Whether to include bias terms in linear and attention layers.
  • dtype - Data type of tensors (e.g., DType::F32, DType::F64).
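The two layouts selected by batch_first hold the same values in a different order. The following sketch uses a plain row-major Vec<f32> instead of a zyx Tensor, and its helper name is hypothetical; it shows where the element at (sequence position s, batch index b, feature f) moves when converting from the batch_first = false layout to the batch_first = true layout.

```rust
/// Reorders a row-major buffer from [seq, batch, feature] (batch_first = false)
/// to [batch, seq, feature] (batch_first = true).
fn seq_first_to_batch_first(data: &[f32], seq: usize, batch: usize, feature: usize) -> Vec<f32> {
    assert_eq!(data.len(), seq * batch * feature, "buffer size must match the shape");
    let mut out = vec![0.0_f32; data.len()];
    for s in 0..seq {
        for b in 0..batch {
            for f in 0..feature {
                // Read from [s, b, f] in the source and write to [b, s, f] in the output.
                out[(b * seq + s) * feature + f] = data[(s * batch + b) * feature + f];
            }
        }
    }
    out
}

fn main() {
    // seq = 2, batch = 3, feature = 1; each value encodes s * 3 + b.
    let seq_first = [0.0_f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let batch_first = seq_first_to_batch_first(&seq_first, 2, 3, 1);
    println!("{:?}", batch_first); // [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
}
```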
Returns

Returns a Result containing the new TransformerDecoderLayer or a ZyxError if initialization fails.
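This page does not list which settings make initialization fail. As a hedged sketch, the checks below are ones commonly applied to a Transformer layer configuration, most notably that d_model must split evenly across nhead attention heads. The DecoderLayerConfig struct, its validate method, and the String error messages are hypothetical stand-ins for whatever ZyxError the real constructor returns.

```rust
/// Hypothetical mirror of the numeric arguments of `new`.
#[derive(Debug, Clone, Copy)]
struct DecoderLayerConfig {
    d_model: u64,
    nhead: u64,
    dim_feedforward: u64,
    dropout: f32,
    layer_norm_eps: f64,
}

impl DecoderLayerConfig {
    /// Returns the per-head dimension, or a description of the first invalid setting.
    fn validate(&self) -> Result<u64, String> {
        if self.d_model == 0 || self.nhead == 0 || self.dim_feedforward == 0 {
            return Err("d_model, nhead and dim_feedforward must be positive".to_string());
        }
        if self.d_model % self.nhead != 0 {
            return Err(format!(
                "d_model ({}) must be divisible by nhead ({})",
                self.d_model, self.nhead
            ));
        }
        // NaN fails this range check as well.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(format!("dropout must be in [0, 1], got {}", self.dropout));
        }
        if !(self.layer_norm_eps > 0.0) {
            return Err("layer_norm_eps must be positive".to_string());
        }
        Ok(self.d_model / self.nhead)
    }
}

fn main() {
    let config = DecoderLayerConfig {
        d_model: 512,
        nhead: 8,
        dim_feedforward: 2048,
        dropout: 0.1,
        layer_norm_eps: 1e-5,
    };
    match config.validate() {
        Ok(head_dim) => println!("valid; each head uses {} features", head_dim),
        Err(msg) => println!("invalid: {}", msg),
    }
}
```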

pub fn forward(
    &self,
    tgt: &Tensor,
    memory: &Tensor,
    tgt_mask: Option<impl Into<Tensor>>,
    memory_mask: Option<impl Into<Tensor>>,
    tgt_key_padding_mask: Option<impl Into<Tensor>>,
    memory_key_padding_mask: Option<impl Into<Tensor>>,
    tgt_is_causal: bool,
    memory_is_causal: bool,
) -> Result<Tensor, ZyxError>

Performs a forward pass through the decoder layer.

Arguments
  • tgt - Target sequence tensor (decoder input).
  • memory - Memory tensor from the encoder (encoder output).
  • tgt_mask - Optional mask for self-attention on the target sequence.
  • memory_mask - Optional mask for cross-attention on the memory sequence.
  • tgt_key_padding_mask - Optional padding mask for target tokens.
  • memory_key_padding_mask - Optional padding mask for memory tokens.
  • tgt_is_causal - Whether to apply causal masking to target self-attention (autoregressive decoding).
  • memory_is_causal - Whether to apply causal masking in cross-attention.
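The mask arguments follow a common pattern: a causal mask keeps each target position from attending to later positions, and a key padding mask hides padded keys. This page does not say which tensor encoding zyx expects, so the sketch below assumes the additive convention used by PyTorch (0.0 keeps a score, negative infinity masks it; true in a boolean padding mask marks a padded key) and works on plain vectors. All function names are hypothetical.

```rust
/// Additive causal mask: query position i may attend to key j only when j <= i.
fn causal_mask(len: usize) -> Vec<Vec<f32>> {
    (0..len)
        .map(|i| (0..len).map(|j| if j <= i { 0.0 } else { f32::NEG_INFINITY }).collect())
        .collect()
}

/// Merges a key padding mask (true = padded key, ignored) into an additive mask.
fn apply_key_padding(mask: &mut [Vec<f32>], key_is_padding: &[bool]) {
    for row in mask.iter_mut() {
        for (m, &pad) in row.iter_mut().zip(key_is_padding) {
            if pad {
                *m = f32::NEG_INFINITY;
            }
        }
    }
}

/// Softmax over one row of attention scores after adding the mask row.
/// Assumes at least one key in the row is unmasked; otherwise the result is NaN.
fn masked_softmax(scores: &[f32], mask_row: &[f32]) -> Vec<f32> {
    let logits: Vec<f32> = scores.iter().zip(mask_row).map(|(s, m)| s + m).collect();
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let mut mask = causal_mask(3);
    // Treat the last key as padding.
    apply_key_padding(&mut mask, &[false, false, true]);
    for (i, row) in mask.iter().enumerate() {
        println!("query {}: {:?}", i, masked_softmax(&[1.0, 1.0, 1.0], row));
    }
}
```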
Returns

Returns a Result containing the output tensor of the decoder layer or a ZyxError.

Behavior
  1. Applies self-attention on the target sequence, followed by dropout and a residual connection.
  2. Applies cross-attention with the encoder memory, followed by dropout and a residual connection.
  3. Passes the result through the feedforward network (with activation), followed by dropout and a residual connection.

Layer normalization wraps each of these three sub-layers: it is applied to the sub-layer input when norm_first is true (pre-norm), or to the sum produced by each residual connection when norm_first is false (post-norm).
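The ordering of operations can be summarized as a trace. The sketch below assumes the standard Transformer formulation; it ignores the dropout inside the feedforward network and all mask handling, and it only records operation names. The labels norm1 to norm3 are hypothetical names for the three layer normalizations, not zyx field names.

```rust
/// Returns the sequence of operations applied to the target in the standard formulation.
fn decoder_op_order(norm_first: bool) -> Vec<&'static str> {
    const SUBLAYERS: [(&str, &str); 3] = [
        ("norm1", "self_attn(x, x, x)"),
        ("norm2", "cross_attn(x, memory, memory)"),
        ("norm3", "feed_forward(x)"),
    ];
    let mut ops = Vec::new();
    for (norm, sublayer) in SUBLAYERS {
        if norm_first {
            // Pre-norm: x = x + dropout(sublayer(norm(x)))
            ops.extend([norm, sublayer, "dropout", "residual_add"]);
        } else {
            // Post-norm: x = norm(x + dropout(sublayer(x)))
            ops.extend([sublayer, "dropout", "residual_add", norm]);
        }
    }
    ops
}

fn main() {
    println!("pre-norm:  {:?}", decoder_op_order(true));
    println!("post-norm: {:?}", decoder_op_order(false));
}
```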

Trait Implementations

impl Debug for TransformerDecoderLayer

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<'a> IntoIterator for &'a TransformerDecoderLayer

type Item = &'a Tensor

The type of the elements being iterated over.

type IntoIter = IntoIter<&'a Tensor>

Which kind of iterator are we turning this into?

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value.

impl<'a> IntoIterator for &'a mut TransformerDecoderLayer

type Item = &'a mut Tensor

The type of the elements being iterated over.

type IntoIter = IntoIter<&'a mut Tensor>

Which kind of iterator are we turning this into?

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value.
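The two IntoIterator implementations let a layer be used directly in a for loop over its parameter tensors, by shared or by mutable reference. The sketch below shows how such implementations are typically written for a struct with a fixed set of parameters. ToyLayer and its Tensor type are hypothetical stand-ins, and the sketch assumes that IntoIter here refers to std::vec::IntoIter, which the page does not state explicitly.

```rust
/// Hypothetical stand-in for a tensor: just a flat buffer of values.
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    data: Vec<f32>,
}

/// Hypothetical layer owning two parameter tensors.
struct ToyLayer {
    weight: Tensor,
    bias: Tensor,
}

impl<'a> IntoIterator for &'a ToyLayer {
    type Item = &'a Tensor;
    type IntoIter = std::vec::IntoIter<&'a Tensor>;

    fn into_iter(self) -> Self::IntoIter {
        vec![&self.weight, &self.bias].into_iter()
    }
}

impl<'a> IntoIterator for &'a mut ToyLayer {
    type Item = &'a mut Tensor;
    type IntoIter = std::vec::IntoIter<&'a mut Tensor>;

    fn into_iter(self) -> Self::IntoIter {
        // Mutable borrows of two distinct fields can coexist.
        vec![&mut self.weight, &mut self.bias].into_iter()
    }
}

/// Counts scalar parameters by iterating over shared references.
fn parameter_count(layer: &ToyLayer) -> usize {
    let mut count = 0;
    for tensor in layer {
        count += tensor.data.len();
    }
    count
}

/// Resets every parameter by iterating over mutable references.
fn zero_parameters(layer: &mut ToyLayer) {
    for tensor in layer {
        tensor.data.iter_mut().for_each(|v| *v = 0.0);
    }
}

fn main() {
    let mut layer = ToyLayer {
        weight: Tensor { data: vec![1.0, 2.0, 3.0, 4.0] },
        bias: Tensor { data: vec![0.5, -0.5] },
    };
    println!("parameters: {}", parameter_count(&layer)); // parameters: 6
    zero_parameters(&mut layer);
    println!("{:?}", layer.weight);
}
```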

impl Module for TransformerDecoderLayer

fn iter<'a>(&'a self) -> impl Iterator<Item = &'a Tensor>

Iterate over all tensors immutably.

fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Tensor>

Iterate over all tensors mutably.

fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a Tensor)>

Iterate over (name, tensor) pairs immutably, without consuming the module.

fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut Tensor)>

Iterate over (name, tensor) pairs mutably.

fn realize(&self) -> Result<(), ZyxError>

Realize all tensors in the module.

fn set_params(&mut self, params: &mut HashMap<String, Tensor>)

Set parameters by name from params, removing each matched entry from params. Parameters with no matching entry in params are skipped and keep their current values.
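The set_params contract can be illustrated with a plain HashMap: every module parameter whose name appears in params is replaced and its entry is removed from the map, parameters with no entry keep their values, and (presumably) entries the module does not recognize remain in the map. The sketch below is hypothetical; ToyModule and the parameter names are illustrative and are not the names zyx generates.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a tensor.
#[derive(Debug, Clone, PartialEq)]
struct Tensor(Vec<f32>);

/// Hypothetical module that stores its parameters under string names.
struct ToyModule {
    params: Vec<(String, Tensor)>,
}

impl ToyModule {
    /// Replaces each parameter found in `params` (by name) and removes that entry
    /// from the map; parameters without a matching entry are left unchanged.
    fn set_params(&mut self, params: &mut HashMap<String, Tensor>) {
        for (name, tensor) in self.params.iter_mut() {
            if let Some(new_value) = params.remove(name.as_str()) {
                *tensor = new_value;
            }
        }
    }
}

fn main() {
    let mut module = ToyModule {
        params: vec![
            ("linear1.weight".to_string(), Tensor(vec![0.0, 0.0])),
            ("linear1.bias".to_string(), Tensor(vec![0.0])),
        ],
    };
    let mut loaded = HashMap::new();
    loaded.insert("linear1.weight".to_string(), Tensor(vec![1.0, 2.0]));
    loaded.insert("unused.weight".to_string(), Tensor(vec![9.0]));

    module.set_params(&mut loaded);
    println!("{:?}", module.params);
    println!("left over: {:?}", loaded.keys().collect::<Vec<_>>()); // left over: ["unused.weight"]
}
```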

fn save(&self, path: impl AsRef<Path>) -> Result<(), ZyxError>

Save tensors or modules to a file whose format is determined by the file extension. Currently, safetensors is the only supported format.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.