zyx_nn/transformer_decoder_layer.rs

// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use crate::{LayerNorm, Linear, MultiheadAttention};
use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;

/// A single layer of a Transformer decoder.
///
/// This layer implements the standard Transformer decoder operations:
/// 1. **Self-attention** on the target sequence.
/// 2. **Cross-attention** using the encoder output (memory).
/// 3. **Feedforward network** with activation function.
/// 4. **Residual connections** and **Layer Normalization**.
///
/// The behavior of the layer can be adjusted using `norm_first` (pre-norm vs post-norm),
/// dropout rate, and activation function.
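///
/// Written out schematically (with `FFN(x) = feedforward_2(activation(feedforward_1(x)))`
/// and dropout applied to every sub-layer output), `forward` computes:
///
/// ```text
/// // post-norm (`norm_first == false`)
/// x = LayerNorm1(tgt + Dropout(SelfAttn(tgt, tgt, tgt)))
/// x = LayerNorm2(x   + Dropout(CrossAttn(x, memory, memory)))
/// x = LayerNorm2(x   + Dropout(FFN(x)))
///
/// // pre-norm (`norm_first == true`)
/// x = LayerNorm1(tgt)
/// x = x + Dropout(SelfAttn(x, x, x))
/// x = x + Dropout(CrossAttn(x, memory, memory))
/// x = LayerNorm2(x + Dropout(FFN(x)))
/// ```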
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct TransformerDecoderLayer {
    self_attention: MultiheadAttention,
    cross_attention: MultiheadAttention,
    feedforward_1: Linear,            // d_model -> dim_feedforward
    feedforward_2: Linear,            // dim_feedforward -> d_model
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
    dropout_rate: f32,                // Dropout rate passed as a parameter
    norm_first: bool,                 // Whether to apply norm before layers or after
    activation: fn(Tensor) -> Tensor, // Activation function
}

impl TransformerDecoderLayer {
    /// Creates a new `TransformerDecoderLayer`.
    ///
    /// # Arguments
    ///
    /// * `d_model` - Dimensionality of input embeddings (number of features per token).
    /// * `nhead` - Number of attention heads in self-attention and cross-attention.
    /// * `dim_feedforward` - Hidden dimension of the feedforward network.
    /// * `dropout` - Dropout probability applied after attention and feedforward layers.
    /// * `activation` - Activation function applied inside the feedforward network (e.g., ReLU).
    /// * `layer_norm_eps` - Small epsilon value for numerical stability in layer normalization.
    /// * `batch_first` - If true, input tensors have shape `[batch, seq, feature]`. Otherwise `[seq, batch, feature]`.
    /// * `norm_first` - Whether to apply layer normalization before sub-layers (pre-norm) or after (post-norm).
    /// * `bias` - Whether to include bias terms in linear and attention layers.
    /// * `dtype` - Data type of tensors (e.g., `DType::F32`, `DType::F64`).
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing the new `TransformerDecoderLayer` or a `ZyxError` if initialization fails.
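    ///
    /// # Example
    ///
    /// A minimal construction sketch. The hyperparameter values are illustrative, and it
    /// assumes the type is re-exported at the crate root and that `Tensor::relu` is
    /// available to serve as the activation:
    ///
    /// ```ignore
    /// use zyx::{DType, Tensor};
    /// use zyx_nn::TransformerDecoderLayer;
    ///
    /// let layer = TransformerDecoderLayer::new(
    ///     512,                  // d_model
    ///     8,                    // nhead
    ///     2048,                 // dim_feedforward
    ///     0.1,                  // dropout
    ///     |x: Tensor| x.relu(), // activation (non-capturing closure coerces to fn pointer)
    ///     1e-5,                 // layer_norm_eps
    ///     true,                 // batch_first
    ///     false,                // norm_first (post-norm)
    ///     true,                 // bias
    ///     DType::F32,           // dtype
    /// ).expect("failed to initialize decoder layer");
    /// ```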
    pub fn new(
        d_model: u64,                     // embed_dim
        nhead: u64,                       // num_heads
        dim_feedforward: u64,             // hidden dimension of the feedforward network
        dropout: f32,                     // dropout rate
        activation: fn(Tensor) -> Tensor, // activation function
        layer_norm_eps: f64,              // layer_norm_eps
        batch_first: bool,                // batch_first
        norm_first: bool,                 // norm_first
        bias: bool,                       // use biases in layers
        dtype: DType,                     // tensor data type (e.g., f32, f64)
    ) -> Result<Self, ZyxError> {
        // Create self-attention and cross-attention layers
        let self_attention = MultiheadAttention::new(
            d_model,
            nhead,
            dropout,
            bias,
            false,
            false,
            None,
            None,
            batch_first,
            dtype,
        )?;

        let cross_attention = MultiheadAttention::new(
            d_model,
            nhead,
            dropout,
            bias,
            false,
            false,
            None,
            None,
            batch_first,
            dtype,
        )?;

        // Feedforward network: two Linear layers (d_model -> dim_feedforward -> d_model)
        // with the configured activation applied in between
        let feedforward_1 = Linear::new(d_model, dim_feedforward, bias, dtype)?;
        let feedforward_2 = Linear::new(dim_feedforward, d_model, bias, dtype)?;

        // LayerNorms (first for attention, second for feedforward)
        let layer_norm_1 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;
        let layer_norm_2 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;

        // Return the complete TransformerDecoderLayer
        Ok(TransformerDecoderLayer {
            self_attention,
            cross_attention,
            feedforward_1,
            feedforward_2,
            layer_norm_1,
            layer_norm_2,
            dropout_rate: dropout,
            norm_first,
            activation,
        })
    }

    /// Performs a forward pass through the decoder layer.
    ///
    /// # Arguments
    ///
    /// * `tgt` - Target sequence tensor (decoder input).
    /// * `memory` - Memory tensor from the encoder (encoder output).
    /// * `tgt_mask` - Optional mask for self-attention on the target sequence.
    /// * `memory_mask` - Optional mask for cross-attention on the memory sequence.
    /// * `tgt_key_padding_mask` - Optional padding mask for target tokens.
    /// * `memory_key_padding_mask` - Optional padding mask for memory tokens.
    /// * `tgt_is_causal` - Whether to apply causal masking to target self-attention (autoregressive decoding).
    /// * `memory_is_causal` - Whether to apply causal masking in cross-attention.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing the output tensor of the decoder layer or a `ZyxError`.
    ///
    /// # Behavior
    ///
    /// 1. Applies layer normalization first if `norm_first` is true.
    /// 2. Applies self-attention on the target sequence.
    /// 3. Applies dropout and a residual connection.
    /// 4. Applies cross-attention with the encoder memory.
    /// 5. Applies dropout and a residual connection.
    /// 6. Passes through the feedforward network with activation.
    /// 7. Applies the final residual connection and layer normalization.
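    ///
    /// # Example
    ///
    /// A minimal call sketch, given a `layer` built with `new` and input tensors `tgt`
    /// and `memory` produced elsewhere (e.g. embedded target tokens and the encoder
    /// output); `None::<Tensor>` is passed to skip the optional masks:
    ///
    /// ```ignore
    /// let output = layer.forward(
    ///     &tgt,           // target sequence, `[batch, seq, d_model]` with `batch_first = true`
    ///     &memory,        // encoder output with the same layout
    ///     None::<Tensor>, // tgt_mask
    ///     None::<Tensor>, // memory_mask
    ///     None::<Tensor>, // tgt_key_padding_mask
    ///     None::<Tensor>, // memory_key_padding_mask
    ///     true,           // tgt_is_causal: causal self-attention for autoregressive decoding
    ///     false,          // memory_is_causal
    /// ).expect("decoder layer forward pass failed");
    /// ```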
    pub fn forward(
        &self,
        tgt: &Tensor,                        // Target sequence (input to the decoder)
        memory: &Tensor,                     // Memory sequence (encoder output)
        tgt_mask: Option<impl Into<Tensor>>, // Optional mask for target sequence (self-attention)
        memory_mask: Option<impl Into<Tensor>>, // Optional mask for memory sequence (cross-attention)
        tgt_key_padding_mask: Option<impl Into<Tensor>>, // Optional padding mask for target
        memory_key_padding_mask: Option<impl Into<Tensor>>, // Optional padding mask for memory
        tgt_is_causal: bool,    // Whether to apply causal masking to target (autoregressive)
        memory_is_causal: bool, // Whether to apply causal masking to memory
    ) -> Result<Tensor, ZyxError> {
        let mut output = tgt.clone();

        // Apply LayerNorm first if norm_first is true
        if self.norm_first {
            output = self.layer_norm_1.forward(&output)?;
        }

        // Self-Attention: Apply self-attention to the target sequence
        let (attn_output, _) = self.self_attention.forward(
            &output,              // query = tgt
            &output,              // key = tgt
            &output,              // value = tgt
            tgt_key_padding_mask, // Padding mask for tgt
            true,                 // need_weights = true (weights are discarded here)
            tgt_mask,             // tgt_mask (optional)
            true,                 // average_attn_weights = true
            tgt_is_causal,        // Is causal self-attention (autoregressive)?
        )?;

        // Apply dropout after self-attention
        let attn_output = attn_output.dropout(self.dropout_rate);

        // Add residual connection after self-attention
        output = output + attn_output;

        // Apply LayerNorm after self-attention if norm_first is false
        if !self.norm_first {
            output = self.layer_norm_1.forward(&output)?;
        }

        // Cross-Attention: Apply cross-attention using memory (encoder output)
        let (cross_attn_output, _) = self.cross_attention.forward(
            &output,                 // query = tgt (output from self-attention)
            memory,                  // key = memory (encoder output)
            memory,                  // value = memory
            memory_key_padding_mask, // Padding mask for memory
            true,                    // need_weights = true (weights are discarded here)
            memory_mask,             // memory_mask (optional)
            true,                    // average_attn_weights = true
            memory_is_causal,        // Is causal attention for memory
        )?;

        // Apply dropout after cross-attention
        let cross_attn_output = cross_attn_output.dropout(self.dropout_rate);

        // Add residual connection after cross-attention
        output = output + cross_attn_output;

        // Apply LayerNorm after cross-attention if norm_first is false
        if !self.norm_first {
            output = self.layer_norm_2.forward(&output)?;
        }

        // Feedforward Network: expand to dim_feedforward, apply the activation,
        // then project back to d_model so the residual addition below is shape-consistent
        let ff_output = self.feedforward_1.forward(&output)?;
        let ff_output = (self.activation)(ff_output);
        let ff_output = self.feedforward_2.forward(&ff_output)?;

        // Apply dropout after feedforward
        let ff_output = ff_output.dropout(self.dropout_rate);

        // Add residual connection after feedforward
        output = output + ff_output;

        // Apply final LayerNorm after feedforward
        output = self.layer_norm_2.forward(&output)?;

        Ok(output)
    }
}