zyx_nn/transformer_encoder_layer.rs

// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use crate::{LayerNorm, Linear, MultiheadAttention};
use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;

/// A single Transformer Encoder layer, analogous to `torch.nn.TransformerEncoderLayer`.
///
/// This layer implements a standard Transformer encoder block with a multi-head self-attention
/// mechanism followed by a position-wise feedforward network. Layer normalization can be applied
/// either before ("pre-norm") or after ("post-norm") the attention and feedforward sub-layers.
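///
/// Computation order, schematically (a pseudocode sketch of the `forward` pass below):
///
/// ```text
/// // norm_first = true (pre-norm)
/// x = x + SelfAttn(LayerNorm1(x))
/// x = x + FFN(LayerNorm2(x))
///
/// // norm_first = false (post-norm)
/// x = LayerNorm1(x + SelfAttn(x))
/// x = LayerNorm2(x + FFN(x))
/// ```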
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct TransformerEncoderLayer {
    /// The multi-head self-attention module.
    pub self_attn: MultiheadAttention,
    /// The first linear layer of the feedforward network (expansion to `dim_feedforward`).
    pub linear1: Linear,
    /// Dropout probability applied after the attention output and within the feedforward network.
    pub dropout: f32,
    /// The second linear layer of the feedforward network (projection back to `d_model`).
    pub linear2: Linear,
    /// LayerNorm applied after the self-attention block (or before it if `norm_first` is true).
    pub norm1: LayerNorm,
    /// LayerNorm applied after the feedforward block (or before it if `norm_first` is true).
    pub norm2: LayerNorm,
    /// The activation function used in the feedforward network (e.g., ReLU, GELU).
    pub activation: fn(Tensor) -> Tensor,
    /// If `true`, applies layer normalization before each sub-layer (pre-norm).
    pub norm_first: bool,
    /// If `true`, expects input tensors of shape `(batch_size, seq_len, d_model)`.
    pub batch_first: bool,
}

impl TransformerEncoderLayer {
    /// Constructs a new `TransformerEncoderLayer`.
    ///
    /// # Arguments
    ///
    /// * `d_model` - The number of expected features in the input (embedding size).
    /// * `nhead` - The number of attention heads.
    /// * `dim_feedforward` - The hidden dimension of the feedforward network.
    /// * `dropout` - Dropout probability applied after attention and within the feedforward network.
    /// * `activation` - Activation function used in the feedforward network.
    /// * `layer_norm_eps` - Epsilon value for numerical stability in layer normalization.
    /// * `batch_first` - If `true`, input/output tensors are expected in `(batch, seq, feature)` format.
    /// * `norm_first` - If `true`, applies layer normalization before sub-layers (pre-norm).
    /// * `bias` - If `true`, the linear and normalization layers include bias terms.
    /// * `dtype` - The data type of the layer’s parameters and outputs.
    ///
    /// # Returns
    ///
    /// A `Result` containing the initialized `TransformerEncoderLayer` or a `ZyxError`.
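    ///
    /// # Example
    ///
    /// A minimal construction sketch. The activation closure assumes a `relu` method on
    /// `Tensor`; any non-capturing closure usable as `fn(Tensor) -> Tensor` works, and the
    /// `DType` variant name is likewise an assumption.
    ///
    /// ```ignore
    /// let layer = TransformerEncoderLayer::new(
    ///     512,                  // d_model
    ///     8,                    // nhead
    ///     2048,                 // dim_feedforward
    ///     0.1,                  // dropout
    ///     |t: Tensor| t.relu(), // activation (assumes `Tensor::relu` exists)
    ///     1e-5,                 // layer_norm_eps
    ///     true,                 // batch_first
    ///     false,                // norm_first (post-norm, as in the original Transformer)
    ///     true,                 // bias
    ///     DType::F32,           // dtype
    /// )?;
    /// ```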
    pub fn new(
        d_model: u64,
        nhead: u64,
        dim_feedforward: u64,
        dropout: f32,
        activation: fn(Tensor) -> Tensor,
        layer_norm_eps: f64,
        batch_first: bool,
        norm_first: bool,
        bias: bool,
        dtype: DType,
    ) -> Result<Self, ZyxError> {
        // --- Multihead self-attention ---
        let self_attn = MultiheadAttention::new(
            d_model,
            nhead,
            dropout,
            bias,
            /* add_bias_kv */ false,
            /* add_zero_attn */ false,
            /* kdim */ None,
            /* vdim */ None,
            batch_first,
            dtype,
        )?;

        // --- Feedforward network ---
        let linear1 = Linear::new(d_model, dim_feedforward, bias, dtype)?;
        let linear2 = Linear::new(dim_feedforward, d_model, bias, dtype)?;

        // --- LayerNorms ---
        let norm1 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;
        let norm2 = LayerNorm::new(d_model, layer_norm_eps, true, bias, dtype)?;

        Ok(Self {
            self_attn,
            linear1,
            dropout,
            linear2,
            norm1,
            norm2,
            activation,
            norm_first,
            batch_first,
        })
    }

    /// Performs a forward pass of the Transformer encoder layer.
    ///
    /// # Arguments
    ///
    /// * `src` - Input tensor of shape `(seq_len, batch_size, d_model)`, or `(batch_size, seq_len, d_model)` if `batch_first` is `true`.
    /// * `src_mask` - Optional attention mask tensor to prevent attention to certain positions.
    /// * `src_key_padding_mask` - Optional mask tensor for padding positions in the input.
    ///
    /// # Returns
    ///
    /// A `Result` containing the output tensor after applying the self-attention and feedforward blocks.
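    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming a layer built with `batch_first = true` and
    /// `d_model = 512`; the `Tensor::randn` constructor and its signature are assumptions
    /// used only for illustration.
    ///
    /// ```ignore
    /// // (batch_size, seq_len, d_model) because batch_first = true
    /// let src = Tensor::randn([4, 128, 512], DType::F32);
    /// // No attention mask, no key padding mask.
    /// let out = layer.forward(src, None, None)?;
    /// // The output keeps the input shape: (4, 128, 512).
    /// ```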
    pub fn forward(
        &self,
        src: impl Into<Tensor>,
        src_mask: Option<Tensor>,
        src_key_padding_mask: Option<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let mut x = src.into();

        if self.norm_first {
            // Pre-norm variant
            let sa_out = self.self_attention_block(
                self.norm1.forward(&x)?,
                &src_mask,
                &src_key_padding_mask,
            )?;
            x = x + sa_out;

            let ff_out = self.feed_forward_block(self.norm2.forward(&x)?)?;
            x = x + ff_out;
        } else {
            // Post-norm variant
            let sa_out = self.self_attention_block(&x, &src_mask, &src_key_padding_mask)?;
            x = self.norm1.forward(x + sa_out)?;

            let ff_out = self.feed_forward_block(&x)?;
            x = self.norm2.forward(x + ff_out)?;
        }

        Ok(x)
    }

    fn self_attention_block(
        &self,
        x: impl Into<Tensor>,
        src_mask: &Option<Tensor>,
        src_key_padding_mask: &Option<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let x = x.into();
        // Self-attention: query, key, and value are all the same tensor.
        let (attn_output, _weights) = self.self_attn.forward(
            x.clone(), // query
            x.clone(), // key
            x,         // value
            src_key_padding_mask.as_ref(),
            /* need_weights */ false,
            src_mask.as_ref(),
            /* average_attn_weights */ true,
            /* is_causal */ false,
        )?;
        // Dropout after the attention output
        Ok(attn_output.dropout(self.dropout))
    }

    fn feed_forward_block(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let x = x.into();
        let x = self.linear1.forward(&x)?;
        let x = (self.activation)(x);
        let x = x.dropout(self.dropout);
        let x = self.linear2.forward(&x)?;
        Ok(x.dropout(self.dropout))
    }
}