zyx_nn/transformer_encoder_layer.rs
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use crate::{LayerNorm, Linear, MultiheadAttention};
use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;

/// A single Transformer encoder layer, analogous to `torch.nn.TransformerEncoderLayer`.
///
/// This layer implements a standard Transformer encoder block: a multi-head self-attention
/// mechanism followed by a position-wise feedforward network. Layer normalization can be applied
/// either before ("pre-norm") or after ("post-norm") the attention and feedforward sub-layers.
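///
/// In pseudocode, the two variants compute the following (mirroring `forward` below):
///
/// ```text
/// // pre-norm (norm_first = true)
/// x = x + dropout(self_attn(norm1(x)))
/// x = x + dropout(linear2(dropout(activation(linear1(norm2(x))))))
///
/// // post-norm (norm_first = false)
/// x = norm1(x + dropout(self_attn(x)))
/// x = norm2(x + dropout(linear2(dropout(activation(linear1(x))))))
/// ```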
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct TransformerEncoderLayer {
    /// The multi-head self-attention module.
    pub self_attn: MultiheadAttention,
    /// The first linear layer of the feedforward network (expansion to `dim_feedforward`).
    pub linear1: Linear,
    /// Dropout probability applied after the attention and feedforward sub-layers.
    pub dropout: f32,
    /// The second linear layer of the feedforward network (projection back to `d_model`).
    pub linear2: Linear,
    /// LayerNorm applied after the self-attention block (or before it if `norm_first` is `true`).
    pub norm1: LayerNorm,
    /// LayerNorm applied after the feedforward block (or before it if `norm_first` is `true`).
    pub norm2: LayerNorm,
    /// The activation function used in the feedforward network (e.g. ReLU, GELU).
    pub activation: fn(Tensor) -> Tensor,
    /// If `true`, applies layer normalization before each sub-layer (pre-norm).
    pub norm_first: bool,
    /// If `true`, expects input tensors of shape `(batch_size, seq_len, d_model)`.
    pub batch_first: bool,
}

impl TransformerEncoderLayer {
    /// Constructs a new `TransformerEncoderLayer`.
    ///
    /// # Arguments
    ///
    /// * `d_model` - The number of expected features in the input (embedding size).
    /// * `nhead` - The number of attention heads.
    /// * `dim_feedforward` - The hidden dimension of the feedforward network.
    /// * `dropout` - Dropout probability applied after the attention and feedforward layers.
    /// * `activation` - Activation function used in the feedforward network.
    /// * `layer_norm_eps` - Epsilon value added for numerical stability in layer normalization.
    /// * `batch_first` - If `true`, input/output tensors are expected in `(batch, seq, feature)` format.
    /// * `norm_first` - If `true`, applies layer normalization before each sub-layer (pre-norm).
    /// * `bias` - If `true`, the linear and layer normalization modules include learnable bias terms.
    /// * `dtype` - The data type of the layer's parameters and outputs.
    ///
    /// # Returns
    ///
    /// A `Result` containing the initialized `TransformerEncoderLayer` or a `ZyxError`.
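    ///
    /// # Example
    ///
    /// A minimal construction sketch. The `relu` wrapper below is illustrative and
    /// assumes zyx's `Tensor` exposes a `relu` method; any `fn(Tensor) -> Tensor`
    /// can be passed as the activation.
    ///
    /// ```ignore
    /// use zyx::{DType, Tensor};
    ///
    /// // Illustrative activation wrapper matching `fn(Tensor) -> Tensor`.
    /// fn relu(x: Tensor) -> Tensor {
    ///     x.relu()
    /// }
    ///
    /// let layer = TransformerEncoderLayer::new(
    ///     512,   // d_model
    ///     8,     // nhead
    ///     2048,  // dim_feedforward
    ///     0.1,   // dropout
    ///     relu,  // activation
    ///     1e-5,  // layer_norm_eps
    ///     true,  // batch_first: inputs are (batch, seq, feature)
    ///     false, // norm_first: post-norm, as in the original Transformer
    ///     true,  // bias
    ///     DType::F32,
    /// )?;
    /// ```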
    pub fn new(
        d_model: u64,
        nhead: u64,
        dim_feedforward: u64,
        dropout: f32,
        activation: fn(Tensor) -> Tensor,
        layer_norm_eps: f64,
        batch_first: bool,
        norm_first: bool,
        bias: bool,
        dtype: DType,
    ) -> Result<Self, ZyxError> {
        // --- Multihead self-attention ---
        let self_attn = MultiheadAttention::new(
            d_model,
            nhead,
            dropout,
            bias,
            /* add_bias_kv */ false,
            /* add_zero_attn */ false,
            /* kdim */ None,
            /* vdim */ None,
            batch_first,
            dtype,
        )?;

        // --- Feedforward network ---
        let linear1 = Linear::new(d_model, dim_feedforward, bias, dtype)?;
        let linear2 = Linear::new(dim_feedforward, d_model, bias, dtype)?;

        // --- LayerNorms ---
        let norm1 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;
        let norm2 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;

        Ok(Self {
            self_attn,
            linear1,
            dropout,
            linear2,
            norm1,
            norm2,
            activation,
            norm_first,
            batch_first,
        })
    }

    /// Performs a forward pass of the Transformer encoder layer.
    ///
    /// # Arguments
    ///
    /// * `src` - Input tensor of shape `(seq_len, batch_size, d_model)`, or `(batch_size, seq_len, d_model)` if `batch_first` is `true`.
    /// * `src_mask` - Optional attention mask tensor that prevents attention to certain positions.
    /// * `src_key_padding_mask` - Optional mask tensor marking padding positions in the input.
    ///
    /// # Returns
    ///
    /// A `Result` containing the output tensor after the self-attention and feedforward blocks, or a `ZyxError`.
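    ///
    /// # Example
    ///
    /// A usage sketch, assuming the layer was built with `batch_first = true` and that
    /// `Tensor::randn` is available for creating the input; both masks are omitted.
    ///
    /// ```ignore
    /// // src: (batch = 2, seq = 16, d_model = 512)
    /// let src = Tensor::randn([2, 16, 512], DType::F32);
    /// // No attention mask, no key padding mask.
    /// let out = layer.forward(src, None, None)?;
    /// // `out` has the same shape as `src`.
    /// ```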
    pub fn forward(
        &self,
        src: impl Into<Tensor>,
        src_mask: Option<Tensor>,
        src_key_padding_mask: Option<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let mut x = src.into();

        if self.norm_first {
            // Pre-norm variant: normalize before each sub-layer, then add the residual.
            let sa_out = self.self_attention_block(
                self.norm1.forward(&x)?,
                &src_mask,
                &src_key_padding_mask,
            )?;
            x = x + sa_out;

            let ff_out = self.feed_forward_block(self.norm2.forward(&x)?)?;
            x = x + ff_out;
        } else {
            // Post-norm variant: add the residual first, then normalize.
            let sa_out = self.self_attention_block(&x, &src_mask, &src_key_padding_mask)?;
            x = self.norm1.forward(x + sa_out)?;

            let ff_out = self.feed_forward_block(&x)?;
            x = self.norm2.forward(x + ff_out)?;
        }

        Ok(x)
    }

    /// Self-attention sub-layer: attends `x` to itself and applies dropout to the result.
    fn self_attention_block(
        &self,
        x: impl Into<Tensor>,
        src_mask: &Option<Tensor>,
        src_key_padding_mask: &Option<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let x = x.into();
        let (attn_output, _weights) = self.self_attn.forward(
            x.clone(), // query
            x.clone(), // key (query = key = value for self-attention)
            x,         // value
            src_key_padding_mask.as_ref(),
            /* need_weights */ false,
            src_mask.as_ref(),
            /* average_attn_weights */ true,
            /* is_causal */ false,
        )?;
        // Dropout on the attention output.
        Ok(attn_output.dropout(self.dropout))
    }

    /// Feedforward sub-layer: linear1 -> activation -> dropout -> linear2 -> dropout.
    fn feed_forward_block(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let x = x.into();
        let x = self.linear1.forward(&x)?;
        let x = (self.activation)(x);
        let x = x.dropout(self.dropout);
        let x = self.linear2.forward(&x)?;
        Ok(x.dropout(self.dropout))
    }
}