// zyx_nn/positional_encoding.rs
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
6
7/// Sinusoidal positional encoding module for transformers.
8///
9/// This module adds fixed (non-learnable) positional encodings to input embeddings.
10/// It uses the same formulation as in the original "Attention is All You Need" paper,
11/// based on sine and cosine functions of different frequencies.
12///
13/// It supports both `f32` and `f64` types and applies dropout after adding the encodings.
14#[derive(Debug, Module)]
15#[cfg_attr(feature = "py", pyo3::pyclass)]
16pub struct PositionalEncoding {
17    /// Precomputed positional encodings of shape `[max_len, d_model]`.
18    pe: Tensor,
19
20    /// Dropout probability to apply after adding the positional encoding.
21    dropout_prob: f32,
22}
23
24impl PositionalEncoding {
25    /// Creates a new `PositionalEncoding` module.
26    ///
27    /// # Arguments
28    ///
29    /// * `d_model` - The embedding dimension (must match the input's last dimension).
30    /// * `max_len` - Maximum sequence length this module will support.
31    /// * `dropout_prob` - Dropout probability applied after adding the positional encoding.
32    /// * `dtype` - Data type of the encoding (must be `DType::F32` or `DType::F64`).
33    ///
34    /// # Errors
35    ///
36    /// Returns a [`ZyxError::ShapeError`] if a non-floating-point dtype is used.
37    ///
38    /// # Example
39    ///
40    /// ```rust ignore
41    /// let pe = PositionalEncoding::new(512, 1024, 0.1, DType::F32)?;
42    /// ```
43    pub fn new(
44        d_model: u64,
45        max_len: usize,
46        dropout_prob: f32,
47        dtype: DType,
48    ) -> Result<Self, ZyxError> {
49        // Enforce floating point type
50        match dtype {
51            DType::F32 | DType::F64 => {}
52            _ => {
53                return Err(ZyxError::ShapeError(
54                    "PositionalEncoding requires dtype F32 or F64".into(),
55                ))
56            }
57        }
58
59        // position: [max_len, 1]
60        let position = Tensor::arange(0i64, max_len as i64, 1i64)?
61            .cast(dtype)
62            .unsqueeze(1)?;
63
64        // div_term: [d_model // 2]
65        let div_term_i64 = Tensor::arange(0i64, d_model as i64, 2i64)?;
66        let div_term = div_term_i64.cast(dtype) / Tensor::from(d_model as f64).cast(dtype);
67
68        let div_term = Tensor::from(10000.0f64).pow(&div_term)?; // [d_model // 2]
69
70        let angle_rates = &position / div_term.unsqueeze(0)?; // [max_len, d_model // 2]
71        let sin_part = angle_rates.sin(); // [max_len, d_model // 2]
72        let cos_part = angle_rates.cos(); // [max_len, d_model // 2]
73
74        // Interleave sin and cos: [max_len, d_model]
75        let mut parts = Vec::with_capacity(d_model as usize);
76        for i in 0..(d_model / 2) {
77            parts.push(sin_part.slice((0..max_len, i))?.unsqueeze(1)?);
78            parts.push(cos_part.slice((0..max_len, i))?.unsqueeze(1)?);
79        }
80
81        // Pad if d_model is odd
82        if d_model % 2 != 0 {
83            let pad = sin_part
84                .slice((0..max_len, d_model / 2 - 1))?
85                .unsqueeze(1)?;
86            parts.push(pad);
87        }
88
89        let pe = Tensor::cat(&parts, 1)?; // [max_len, d_model]
90
91        Ok(Self { pe, dropout_prob })
92    }
93
94    /// Applies positional encoding to the input tensor.
95    ///
96    /// # Arguments
97    ///
98    /// * `x` - A tensor of shape `[batch_size, seq_len, d_model]`.
99    ///
100    /// # Returns
101    ///
102    /// A new tensor with the same shape as the input, with positional encodings added and
103    /// dropout applied.
104    ///
105    /// # Errors
106    ///
107    /// Returns a [`ZyxError::ShapeError`] if:
108    /// - Input tensor is not 3-dimensional.
109    /// - The input dimension `d_model` does not match the positional encoding.
110    /// - The sequence length exceeds the configured `max_len`.
111    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
112        let x = x.into();
113        let shape = x.shape();
114
115        if shape.len() != 3 {
116            return Err(ZyxError::ShapeError(
117                "Expected input of shape [batch, seq, dim]".into(),
118            ));
119        }
120
121        let seq_len = shape[1];
122        let dim = shape[2];
123
124        if dim != self.pe.shape()[1] {
125            return Err(ZyxError::ShapeError(
126                format!(
127                    "Mismatch between input dim {} and positional encoding dim {}",
128                    dim,
129                    self.pe.shape()[1]
130                )
131                .into(),
132            ));
133        }
134
135        if seq_len > self.pe.shape()[0] {
136            return Err(ZyxError::ShapeError(
137                format!(
138                    "Input sequence length {} exceeds positional encoding max_len {}",
139                    seq_len,
140                    self.pe.shape()[0]
141                )
142                .into(),
143            ));
144        }
145
146        let pe_slice = self.pe.slice((0..seq_len, 0..dim))?; // [seq_len, dim]
147        let pe_expanded = pe_slice.unsqueeze(0)?; // [1, seq_len, dim]
148
149        let out = (x + pe_expanded).dropout(self.dropout_prob);
150        Ok(out)
151    }
152}