zyx_nn/transformer_decoder_layer.rs
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use crate::{LayerNorm, Linear, MultiheadAttention};
use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;

/// A single layer of a Transformer decoder.
///
/// This layer implements the standard Transformer decoder operations:
/// 1. **Self-attention** on the target sequence.
/// 2. **Cross-attention** using the encoder output (memory).
/// 3. **Feedforward network** with activation function.
/// 4. **Residual connections** and **Layer Normalization**.
///
/// The behavior of the layer can be adjusted using `norm_first` (pre-norm vs post-norm),
/// dropout rate, and activation function.
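///
/// As a rough sketch (dropout omitted; `norm1`..`norm3` name the three layer norms and
/// `ff` the two-layer feedforward network), the two variants compute:
///
/// ```text
/// // post-norm (norm_first = false)
/// x = norm1(x + self_attn(x))
/// x = norm2(x + cross_attn(x, memory))
/// x = norm3(x + ff(x))
///
/// // pre-norm (norm_first = true)
/// x = x + self_attn(norm1(x))
/// x = x + cross_attn(norm2(x), memory)
/// x = x + ff(norm3(x))
/// ```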
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct TransformerDecoderLayer {
    self_attention: MultiheadAttention,
    cross_attention: MultiheadAttention,
    feedforward_1: Linear,
    feedforward_2: Linear,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
    layer_norm_3: LayerNorm,
    dropout_rate: f32,                // Dropout rate applied after each sub-layer
    norm_first: bool,                 // Apply layer norm before each sub-layer (pre-norm) or after (post-norm)
    activation: fn(Tensor) -> Tensor, // Activation function of the feedforward network
}

impl TransformerDecoderLayer {
    /// Creates a new `TransformerDecoderLayer`.
    ///
    /// # Arguments
    ///
    /// * `d_model` - Dimensionality of input embeddings (number of features per token).
    /// * `nhead` - Number of attention heads in self-attention and cross-attention.
    /// * `dim_feedforward` - Hidden dimension of the feedforward network.
    /// * `dropout` - Dropout probability applied after the attention and feedforward sub-layers.
    /// * `activation` - Activation function used inside the feedforward network (e.g., ReLU).
    /// * `layer_norm_eps` - Small epsilon value for numerical stability in layer normalization.
    /// * `batch_first` - If true, input tensors have shape `[batch, seq, feature]`; otherwise `[seq, batch, feature]`.
    /// * `norm_first` - Whether to apply layer normalization before each sub-layer (pre-norm) or after (post-norm).
    /// * `bias` - Whether to include bias terms in linear and attention layers.
    /// * `dtype` - Data type of tensors (e.g., `DType::F32`, `DType::F64`).
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing the new `TransformerDecoderLayer` or a `ZyxError` if initialization fails.
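    ///
    /// # Example
    ///
    /// A minimal construction sketch (marked `ignore`, so it is not compiled as a doctest).
    /// It assumes the crate is imported as `zyx_nn` and that `zyx::Tensor` provides a `relu`
    /// method; substitute your own activation function if it does not.
    ///
    /// ```ignore
    /// use zyx::{DType, Tensor};
    /// use zyx_nn::TransformerDecoderLayer;
    ///
    /// // Plain function pointer used as the feedforward activation.
    /// fn relu(x: Tensor) -> Tensor {
    ///     x.relu() // assumed to exist on zyx::Tensor
    /// }
    ///
    /// let layer = TransformerDecoderLayer::new(
    ///     512,   // d_model
    ///     8,     // nhead
    ///     2048,  // dim_feedforward
    ///     0.1,   // dropout
    ///     relu,  // activation
    ///     1e-5,  // layer_norm_eps
    ///     true,  // batch_first: [batch, seq, feature]
    ///     false, // norm_first: post-norm
    ///     true,  // bias
    ///     DType::F32,
    /// )?;
    /// ```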
    pub fn new(
        d_model: u64,                     // embed_dim
        nhead: u64,                       // num_heads
        dim_feedforward: u64,             // dim_feedforward
        dropout: f32,                     // dropout rate
        activation: fn(Tensor) -> Tensor, // activation function
        layer_norm_eps: f64,              // layer_norm_eps
        batch_first: bool,                // batch_first
        norm_first: bool,                 // norm_first
        bias: bool,                       // use biases in layers
        dtype: DType,                     // tensor data type (e.g., f32, f64)
    ) -> Result<Self, ZyxError> {
        // Create self-attention and cross-attention layers
        let self_attention = MultiheadAttention::new(
            d_model,
            nhead,
            dropout,
            bias,
            false,
            false,
            None,
            None,
            batch_first,
            dtype,
        )?;

        let cross_attention = MultiheadAttention::new(
            d_model,
            nhead,
            dropout,
            bias,
            false,
            false,
            None,
            None,
            batch_first,
            dtype,
        )?;

        // Create the feedforward network (two Linear layers: d_model -> dim_feedforward -> d_model)
        let feedforward_1 = Linear::new(d_model, dim_feedforward, bias, dtype)?;
        let feedforward_2 = Linear::new(dim_feedforward, d_model, bias, dtype)?;

        // LayerNorms for the self-attention, cross-attention and feedforward sub-layers
        let layer_norm_1 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;
        let layer_norm_2 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;
        let layer_norm_3 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;

        // Return the complete TransformerDecoderLayer
        Ok(TransformerDecoderLayer {
            self_attention,
            cross_attention,
            feedforward_1,
            feedforward_2,
            layer_norm_1,
            layer_norm_2,
            layer_norm_3,
            dropout_rate: dropout,
            norm_first,
            activation,
        })
    }

    /// Performs a forward pass through the decoder layer.
    ///
    /// # Arguments
    ///
    /// * `tgt` - Target sequence tensor (decoder input).
    /// * `memory` - Memory tensor from the encoder (encoder output).
    /// * `tgt_mask` - Optional mask for self-attention on the target sequence.
    /// * `memory_mask` - Optional mask for cross-attention on the memory sequence.
    /// * `tgt_key_padding_mask` - Optional padding mask for target tokens.
    /// * `memory_key_padding_mask` - Optional padding mask for memory tokens.
    /// * `tgt_is_causal` - Whether to apply causal masking to target self-attention (autoregressive decoding).
    /// * `memory_is_causal` - Whether to apply causal masking in cross-attention.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing the output tensor of the decoder layer or a `ZyxError`.
    ///
    /// # Behavior
    ///
    /// Each of the three sub-layers is wrapped in dropout and a residual connection. With
    /// `norm_first == true` the corresponding layer norm is applied to the sub-layer input
    /// (pre-norm); otherwise it is applied after the residual addition (post-norm).
    ///
    /// 1. Self-attention over the target sequence, followed by dropout and a residual connection.
    /// 2. Cross-attention between the target and the encoder memory, followed by dropout and a residual connection.
    /// 3. Feedforward network with activation, followed by dropout and a residual connection.
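    ///
    /// # Example
    ///
    /// A rough usage sketch (marked `ignore`, so it is not compiled as a doctest). `layer` is
    /// assumed to come from [`TransformerDecoderLayer::new`] with `batch_first = true`, and
    /// `Tensor::randn` is assumed to exist for creating dummy inputs.
    ///
    /// ```ignore
    /// // [batch, seq, d_model] layout
    /// let tgt = Tensor::randn([2, 10, 512], DType::F32);
    /// let memory = Tensor::randn([2, 20, 512], DType::F32);
    ///
    /// let out = layer.forward(
    ///     &tgt,
    ///     &memory,
    ///     None::<Tensor>, // tgt_mask
    ///     None::<Tensor>, // memory_mask
    ///     None::<Tensor>, // tgt_key_padding_mask
    ///     None::<Tensor>, // memory_key_padding_mask
    ///     true,           // tgt_is_causal: autoregressive decoding
    ///     false,          // memory_is_causal
    /// )?;
    /// ```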
    pub fn forward(
        &self,
        tgt: &Tensor,                                       // Target sequence (input to the decoder)
        memory: &Tensor,                                    // Memory sequence (encoder output)
        tgt_mask: Option<impl Into<Tensor>>,                // Optional attention mask for target self-attention
        memory_mask: Option<impl Into<Tensor>>,             // Optional attention mask for cross-attention
        tgt_key_padding_mask: Option<impl Into<Tensor>>,    // Optional padding mask for target
        memory_key_padding_mask: Option<impl Into<Tensor>>, // Optional padding mask for memory
        tgt_is_causal: bool,                                // Apply causal masking to target self-attention
        memory_is_causal: bool,                             // Apply causal masking to cross-attention
    ) -> Result<Tensor, ZyxError> {
        let mut output = tgt.clone();

        // Self-attention sub-layer.
        // Pre-norm normalizes the sub-layer input; post-norm normalizes after the residual addition.
        let x = if self.norm_first {
            self.layer_norm_1.forward(&output)?
        } else {
            output.clone()
        };
        let (attn_output, _) = self.self_attention.forward(
            &x,                   // query = tgt
            &x,                   // key = tgt
            &x,                   // value = tgt
            tgt_key_padding_mask, // Padding mask for tgt
            false,                // need_weights = false (attention weights are not used)
            tgt_mask,             // Attention mask for tgt (optional)
            true,                 // average_attn_weights
            tgt_is_causal,        // Causal self-attention (autoregressive)?
        )?;
        // Dropout, then residual connection around self-attention
        output = output + attn_output.dropout(self.dropout_rate);
        if !self.norm_first {
            output = self.layer_norm_1.forward(&output)?;
        }

        // Cross-attention sub-layer using memory (encoder output)
        let x = if self.norm_first {
            self.layer_norm_2.forward(&output)?
        } else {
            output.clone()
        };
        let (cross_attn_output, _) = self.cross_attention.forward(
            &x,                      // query = output of the self-attention sub-layer
            memory,                  // key = memory (encoder output)
            memory,                  // value = memory
            memory_key_padding_mask, // Padding mask for memory
            false,                   // need_weights = false (attention weights are not used)
            memory_mask,             // Attention mask for memory (optional)
            true,                    // average_attn_weights
            memory_is_causal,        // Causal cross-attention?
        )?;
        // Dropout, then residual connection around cross-attention
        output = output + cross_attn_output.dropout(self.dropout_rate);
        if !self.norm_first {
            output = self.layer_norm_2.forward(&output)?;
        }

        // Feedforward sub-layer: d_model -> dim_feedforward -> activation -> d_model
        let x = if self.norm_first {
            self.layer_norm_3.forward(&output)?
        } else {
            output.clone()
        };
        let ff_output = self.feedforward_1.forward(&x)?;
        let ff_output = (self.activation)(ff_output);
        let ff_output = self.feedforward_2.forward(&ff_output)?;
        // Dropout, then residual connection around the feedforward network
        output = output + ff_output.dropout(self.dropout_rate);
        if !self.norm_first {
            output = self.layer_norm_3.forward(&output)?;
        }

        Ok(output)
    }
}