pub struct TransformerDecoderLayer { /* private fields */ }
A single layer of a Transformer decoder.
This layer implements the standard Transformer decoder operations:
- Self-attention on the target sequence.
- Cross-attention using the encoder output (memory).
- Feedforward network with activation function.
- Residual connections and Layer Normalization.
The behavior of the layer can be adjusted using norm_first (pre-norm vs post-norm),
dropout rate, and activation function.
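The difference between pre-norm and post-norm can be sketched on plain `f32` slices. This is an illustration only: `layer_norm`, `post_norm_block`, and `pre_norm_block` are hypothetical helpers written for this example, not part of this crate's API, and the normalization omits the learned scale and shift.

```rust
/// Normalizes a vector to zero mean and unit variance.
/// Hypothetical helper: no learned scale or shift is applied.
fn layer_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    // eps keeps the denominator away from zero when the variance is tiny.
    let denom = (var + eps).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// Post-norm ordering: y = LayerNorm(x + sublayer(x)).
fn post_norm_block(x: &[f32], sublayer: impl Fn(&[f32]) -> Vec<f32>, eps: f32) -> Vec<f32> {
    let out = sublayer(x);
    let summed: Vec<f32> = x.iter().zip(&out).map(|(a, b)| a + b).collect();
    layer_norm(&summed, eps)
}

/// Pre-norm ordering: y = x + sublayer(LayerNorm(x)).
fn pre_norm_block(x: &[f32], sublayer: impl Fn(&[f32]) -> Vec<f32>, eps: f32) -> Vec<f32> {
    let out = sublayer(&layer_norm(x, eps));
    x.iter().zip(&out).map(|(a, b)| a + b).collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    // Stand-in for attention or the feedforward network.
    let double = |v: &[f32]| v.iter().map(|a| a * 2.0).collect::<Vec<f32>>();
    let post = post_norm_block(&x, double, 1e-5);
    let pre = pre_norm_block(&x, double, 1e-5);
    println!("post-norm: {post:?}");
    println!("pre-norm:  {pre:?}");
}
```

In the post-norm form the output is always normalized, while in the pre-norm form the input passes through the residual path unchanged.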
Implementations§
impl TransformerDecoderLayer
pub fn new(
d_model: u64,
nhead: u64,
dim_feedforward: u64,
dropout: f32,
activation: fn(Tensor) -> Tensor,
layer_norm_eps: f64,
batch_first: bool,
norm_first: bool,
bias: bool,
dtype: DType,
) -> Result<Self, ZyxError>
Creates a new TransformerDecoderLayer.
§Arguments
- d_model - Dimensionality of input embeddings (number of features per token).
- nhead - Number of attention heads in self-attention and cross-attention.
- dim_feedforward - Hidden dimension of the feedforward network.
- dropout - Dropout probability applied after attention and feedforward layers.
- activation - Activation function applied after the feedforward network (e.g., ReLU).
- layer_norm_eps - Small epsilon value for numerical stability in layer normalization.
- batch_first - If true, input tensors have shape [batch, seq, feature]. Otherwise [seq, batch, feature].
- norm_first - Whether to apply layer normalization before sub-layers (pre-norm) or after (post-norm).
- bias - Whether to include bias terms in linear and attention layers.
- dtype - Data type of tensors (e.g., DType::F32, DType::F64).
§Returns
Returns a Result containing the new TransformerDecoderLayer or a ZyxError if initialization fails.
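To make the two layouts selected by batch_first concrete, the sketch below computes where the same logical element lands in a flat buffer under each shape. It assumes contiguous row-major storage, which is an assumption of this example rather than a documented property of Tensor; `flat_offset` is a hypothetical helper.

```rust
/// Flat row-major offset of the element at batch `b`, position `s`, feature `f`.
/// batch_first == true  -> shape [batch, seq, feature]
/// batch_first == false -> shape [seq, batch, feature]
/// Assumes contiguous row-major storage (an assumption of this sketch).
fn flat_offset(
    b: usize,
    s: usize,
    f: usize,
    batch: usize,
    seq: usize,
    feat: usize,
    batch_first: bool,
) -> usize {
    if batch_first {
        (b * seq + s) * feat + f
    } else {
        (s * batch + b) * feat + f
    }
}

fn main() {
    let (batch, seq, feat) = (2, 3, 4);
    // Same logical element in both layouts: batch 0, position 2, feature 0.
    println!("batch_first: {}", flat_offset(0, 2, 0, batch, seq, feat, true));
    println!("seq_first:   {}", flat_offset(0, 2, 0, batch, seq, feat, false));
}
```

The two offsets differ, so a tensor prepared in one layout has to be permuted before it is passed to a layer configured for the other.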
pub fn forward(
&self,
tgt: &Tensor,
memory: &Tensor,
tgt_mask: Option<impl Into<Tensor>>,
memory_mask: Option<impl Into<Tensor>>,
tgt_key_padding_mask: Option<impl Into<Tensor>>,
memory_key_padding_mask: Option<impl Into<Tensor>>,
tgt_is_causal: bool,
memory_is_causal: bool,
) -> Result<Tensor, ZyxError>
Performs a forward pass through the decoder layer.
§Arguments
- tgt - Target sequence tensor (decoder input).
- memory - Memory tensor from the encoder (encoder output).
- tgt_mask - Optional mask for self-attention on the target sequence.
- memory_mask - Optional mask for cross-attention on the memory sequence.
- tgt_key_padding_mask - Optional padding mask for target tokens.
- memory_key_padding_mask - Optional padding mask for memory tokens.
- tgt_is_causal - Whether to apply causal masking to target self-attention (autoregressive decoding).
- memory_is_causal - Whether to apply causal masking in cross-attention.
§Returns
Returns a Result containing the output tensor of the decoder layer or a ZyxError.
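As a rough illustration of how a causal mask and a key padding mask interact in self-attention, the sketch below builds both as boolean grids. The convention that `true` means "may not attend" is an assumption of this example; the encoding that forward expects (boolean or additive, and which value masks) is not stated here and should be checked against the crate. `causal_mask` and `combine_masks` are hypothetical helpers.

```rust
/// Builds a [len, len] causal mask indexed as mask[query][key].
/// Assumed convention for this sketch: `true` = the query may NOT attend to the key.
fn causal_mask(len: usize) -> Vec<Vec<bool>> {
    (0..len)
        .map(|q| (0..len).map(|k| k > q).collect::<Vec<bool>>())
        .collect()
}

/// Merges a causal mask with a per-key padding mask (`true` = padding token).
/// A key is blocked if it lies in the future or is padding.
fn combine_masks(causal: &[Vec<bool>], key_padding: &[bool]) -> Vec<Vec<bool>> {
    causal
        .iter()
        .map(|row| {
            row.iter()
                .zip(key_padding)
                .map(|(c, p)| *c || *p)
                .collect::<Vec<bool>>()
        })
        .collect()
}

fn main() {
    let causal = causal_mask(3);
    // The last target token is padding.
    let padding = [false, false, true];
    for row in combine_masks(&causal, &padding) {
        println!("{row:?}");
    }
}
```

Each query position sees only itself and earlier positions, and the padded key is blocked for every query.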
§Behavior
- Applies layer normalization if norm_first is true.
- Applies self-attention on the target sequence.
- Applies residual connection and dropout.
- Applies cross-attention with the encoder memory.
- Applies residual connection and dropout.
- Passes through feedforward network with activation.
- Applies final residual connection and layer normalization.
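The steps above can be read as three sub-blocks (self-attention, cross-attention, feedforward), each wrapped in dropout, a residual add, and a layer normalization whose position depends on norm_first. The sketch below prints that ordering as a trace. It follows the conventional Transformer formulation and has not been checked against this crate's source; `layer_steps` and the step names are hypothetical.

```rust
/// Returns the assumed order of operations for one decoder layer, one entry per sub-block.
/// Conventional ordering only; not verified against the crate's implementation.
fn layer_steps(norm_first: bool) -> Vec<String> {
    let mut steps = Vec::new();
    for sub in ["self_attn", "cross_attn", "feedforward"] {
        if norm_first {
            // Pre-norm: normalize the input, run the sub-layer, then add the residual.
            steps.push(format!("norm -> {sub} -> dropout -> add residual"));
        } else {
            // Post-norm: run the sub-layer, add the residual, then normalize.
            steps.push(format!("{sub} -> dropout -> add residual -> norm"));
        }
    }
    steps
}

fn main() {
    for norm_first in [true, false] {
        println!("norm_first = {norm_first}");
        for step in layer_steps(norm_first) {
            println!("  {step}");
        }
    }
}
```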
Trait Implementations§
impl Debug for TransformerDecoderLayer
impl<'a> IntoIterator for &'a TransformerDecoderLayer
impl<'a> IntoIterator for &'a mut TransformerDecoderLayer
impl Module for TransformerDecoderLayer
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Tensor>
Iterate over all tensors mutably
fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a Tensor)>
Iterate over tensors without consuming the module
fn iter_tensors_mut<'a>(
&'a mut self,
) -> impl Iterator<Item = (String, &'a mut Tensor)>
From tensors
Auto Trait Implementations§
impl Freeze for TransformerDecoderLayer
impl RefUnwindSafe for TransformerDecoderLayer
impl Send for TransformerDecoderLayer
impl Sync for TransformerDecoderLayer
impl Unpin for TransformerDecoderLayer
impl UnsafeUnpin for TransformerDecoderLayer
impl UnwindSafe for TransformerDecoderLayer
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.