zyx_nn/positional_encoding.rs
1// Copyright (C) 2025 zk4x
2// SPDX-License-Identifier: LGPL-3.0-only
3
4use zyx::{DType, Tensor, ZyxError};
5use zyx_derive::Module;
6
/// Sinusoidal positional encoding module for transformers.
///
/// This module adds fixed (non-learnable) positional encodings to input embeddings.
/// It uses the same formulation as in the original "Attention is All You Need" paper,
/// based on sine and cosine functions of different frequencies.
///
/// The encoding table is computed once by [`PositionalEncoding::new`]; `forward` only
/// slices the stored table, adds it to the input and applies dropout to the sum.
///
/// The table can be built with `DType::F32` or `DType::F64`; any other dtype is rejected
/// by the constructor.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct PositionalEncoding {
    /// Precomputed positional encodings of shape `[max_len, d_model]`
    /// (rows are positions, columns are embedding features).
    pe: Tensor,

    /// Dropout probability passed to `Tensor::dropout` after the positional encoding
    /// has been added to the input.
    dropout_prob: f32,
}
23
24impl PositionalEncoding {
25 /// Creates a new `PositionalEncoding` module.
26 ///
27 /// # Arguments
28 ///
29 /// * `d_model` - The embedding dimension (must match the input's last dimension).
30 /// * `max_len` - Maximum sequence length this module will support.
31 /// * `dropout_prob` - Dropout probability applied after adding the positional encoding.
32 /// * `dtype` - Data type of the encoding (must be `DType::F32` or `DType::F64`).
33 ///
34 /// # Errors
35 ///
36 /// Returns a [`ZyxError::ShapeError`] if a non-floating-point dtype is used.
37 ///
38 /// # Example
39 ///
40 /// ```rust ignore
41 /// let pe = PositionalEncoding::new(512, 1024, 0.1, DType::F32)?;
42 /// ```
43 pub fn new(
44 d_model: u64,
45 max_len: usize,
46 dropout_prob: f32,
47 dtype: DType,
48 ) -> Result<Self, ZyxError> {
49 // Enforce floating point type
50 match dtype {
51 DType::F32 | DType::F64 => {}
52 _ => {
53 return Err(ZyxError::ShapeError(
54 "PositionalEncoding requires dtype F32 or F64".into(),
55 ))
56 }
57 }
58
59 // position: [max_len, 1]
60 let position = Tensor::arange(0i64, max_len as i64, 1i64)?
61 .cast(dtype)
62 .unsqueeze(1)?;
63
64 // div_term: [d_model // 2]
65 let div_term_i64 = Tensor::arange(0i64, d_model as i64, 2i64)?;
66 let div_term = div_term_i64.cast(dtype) / Tensor::from(d_model as f64).cast(dtype);
67
68 let div_term = Tensor::from(10000.0f64).pow(&div_term)?; // [d_model // 2]
69
70 let angle_rates = &position / div_term.unsqueeze(0)?; // [max_len, d_model // 2]
71 let sin_part = angle_rates.sin(); // [max_len, d_model // 2]
72 let cos_part = angle_rates.cos(); // [max_len, d_model // 2]
73
74 // Interleave sin and cos: [max_len, d_model]
75 let mut parts = Vec::with_capacity(d_model as usize);
76 for i in 0..(d_model / 2) {
77 parts.push(sin_part.slice((0..max_len, i))?.unsqueeze(1)?);
78 parts.push(cos_part.slice((0..max_len, i))?.unsqueeze(1)?);
79 }
80
81 // Pad if d_model is odd
82 if d_model % 2 != 0 {
83 let pad = sin_part
84 .slice((0..max_len, d_model / 2 - 1))?
85 .unsqueeze(1)?;
86 parts.push(pad);
87 }
88
89 let pe = Tensor::cat(&parts, 1)?; // [max_len, d_model]
90
91 Ok(Self { pe, dropout_prob })
92 }
93
94 /// Applies positional encoding to the input tensor.
95 ///
96 /// # Arguments
97 ///
98 /// * `x` - A tensor of shape `[batch_size, seq_len, d_model]`.
99 ///
100 /// # Returns
101 ///
102 /// A new tensor with the same shape as the input, with positional encodings added and
103 /// dropout applied.
104 ///
105 /// # Errors
106 ///
107 /// Returns a [`ZyxError::ShapeError`] if:
108 /// - Input tensor is not 3-dimensional.
109 /// - The input dimension `d_model` does not match the positional encoding.
110 /// - The sequence length exceeds the configured `max_len`.
111 pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
112 let x = x.into();
113 let shape = x.shape();
114
115 if shape.len() != 3 {
116 return Err(ZyxError::ShapeError(
117 "Expected input of shape [batch, seq, dim]".into(),
118 ));
119 }
120
121 let seq_len = shape[1];
122 let dim = shape[2];
123
124 if dim != self.pe.shape()[1] {
125 return Err(ZyxError::ShapeError(
126 format!(
127 "Mismatch between input dim {} and positional encoding dim {}",
128 dim,
129 self.pe.shape()[1]
130 )
131 .into(),
132 ));
133 }
134
135 if seq_len > self.pe.shape()[0] {
136 return Err(ZyxError::ShapeError(
137 format!(
138 "Input sequence length {} exceeds positional encoding max_len {}",
139 seq_len,
140 self.pe.shape()[0]
141 )
142 .into(),
143 ));
144 }
145
146 let pe_slice = self.pe.slice((0..seq_len, 0..dim))?; // [seq_len, dim]
147 let pe_expanded = pe_slice.unsqueeze(0)?; // [1, seq_len, dim]
148
149 let out = (x + pe_expanded).dropout(self.dropout_prob);
150 Ok(out)
151 }
152}