pub struct PositionalEncoding { /* private fields */ }
Sinusoidal positional encoding module for transformers.
This module adds fixed (non-learnable) positional encodings to input embeddings. It uses the same formulation as in the original “Attention is All You Need” paper, based on sine and cosine functions of different frequencies.
It supports both f32 and f64 types and applies dropout after adding the encodings.
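The formulation from the paper can be sketched in plain Rust as follows. This is a minimal illustration, not the crate's implementation: `sinusoidal_table` is a hypothetical helper, nested `Vec`s stand in for a `Tensor`, and the interleaved sine/cosine layout is an assumption (some implementations concatenate the two halves instead).

```rust
/// Builds a `[max_len][d_model]` table of sinusoidal positional encodings.
/// Hypothetical helper for illustration; not part of the crate's API.
fn sinusoidal_table(max_len: usize, d_model: usize) -> Vec<Vec<f32>> {
    let mut table = vec![vec![0.0f32; d_model]; max_len];
    for (pos, row) in table.iter_mut().enumerate() {
        for (i, value) in row.iter_mut().enumerate() {
            // Columns 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
            let pair = (i / 2) as f32;
            let angle = pos as f32 / 10000f32.powf(2.0 * pair / d_model as f32);
            // Even columns hold the sine, odd columns the cosine.
            *value = if i % 2 == 0 { angle.sin() } else { angle.cos() };
        }
    }
    table
}
```

Because the table depends only on `max_len` and `d_model`, it can be computed once at construction time and reused for every input.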
Implementations§
impl PositionalEncoding
pub fn new(
    d_model: u64,
    max_len: usize,
    dropout_prob: f32,
    dtype: DType,
) -> Result<Self, ZyxError>
Creates a new PositionalEncoding module.
§Arguments
- d_model - The embedding dimension (must match the input’s last dimension).
- max_len - Maximum sequence length this module will support.
- dropout_prob - Dropout probability applied after adding the positional encoding.
- dtype - Data type of the encoding (must be DType::F32 or DType::F64).
§Errors
Returns a ZyxError::ShapeError if a non-floating-point dtype is used.
§Example
let pe = PositionalEncoding::new(512, 1024, 0.1, DType::F32)?;
pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError>
Applies positional encoding to the input tensor.
§Arguments
- x - A tensor of shape [batch_size, seq_len, d_model].
§Returns
A new tensor with the same shape as the input, with positional encodings added and dropout applied.
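The addition step can be pictured with the sketch below. It assumes nested `Vec`s in place of `Tensor` and a precomputed `[max_len][d_model]` encoding table; `add_positions` is a hypothetical name, not a function of this crate. Dropout is left out because it needs a random source.

```rust
/// Adds row `pos` of `table` to the token at position `pos` of every sequence.
/// `x` is laid out as [batch_size][seq_len][d_model], `table` as [max_len][d_model].
/// Hypothetical helper: nested `Vec`s stand in for the crate's `Tensor`.
fn add_positions(x: &[Vec<Vec<f32>>], table: &[Vec<f32>]) -> Vec<Vec<Vec<f32>>> {
    x.iter()
        .map(|sequence| {
            sequence
                .iter()
                // `zip` stops at the shorter side, so only the first
                // `seq_len` rows of the table are used.
                .zip(table)
                .map(|(token, encoding)| {
                    token
                        .iter()
                        .zip(encoding)
                        .map(|(v, e)| v + e)
                        .collect::<Vec<f32>>()
                })
                .collect::<Vec<Vec<f32>>>()
        })
        .collect()
}
```

The same table rows are reused for every element of the batch, which is why the output keeps the input's shape.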
§Errors
Returns a ZyxError::ShapeError if:
- The input tensor is not 3-dimensional.
- The input dimension d_model does not match the positional encoding.
- The sequence length exceeds the configured max_len.
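The three error conditions can be expressed as a small validation routine. The sketch below is illustrative only: `check_input_shape` is a hypothetical function, and a plain `String` replaces `ZyxError::ShapeError` so the snippet compiles without the crate.

```rust
/// Validates a `[batch_size, seq_len, d_model]` shape against the module's
/// configuration. Hypothetical stand-in for the checks `forward` performs.
fn check_input_shape(shape: &[usize], d_model: usize, max_len: usize) -> Result<(), String> {
    // The input must have exactly three dimensions.
    let (seq_len, last) = match shape {
        [_, seq_len, last] => (*seq_len, *last),
        _ => return Err(format!("expected 3 dimensions, got {}", shape.len())),
    };
    // The last dimension must equal the embedding size.
    if last != d_model {
        return Err(format!("last dimension {last} does not match d_model {d_model}"));
    }
    // A sequence of exactly `max_len` tokens is still allowed.
    if seq_len > max_len {
        return Err(format!("sequence length {seq_len} exceeds max_len {max_len}"));
    }
    Ok(())
}
```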
Trait Implementations§
impl Debug for PositionalEncoding
impl<'a> IntoIterator for &'a PositionalEncoding
impl<'a> IntoIterator for &'a mut PositionalEncoding
impl Module for PositionalEncoding
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Tensor>
Iterate over all tensors mutably
fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a Tensor)>
Iterate over tensors without consuming the module
fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut Tensor)>
From tensors
Auto Trait Implementations§
impl Freeze for PositionalEncoding
impl RefUnwindSafe for PositionalEncoding
impl Send for PositionalEncoding
impl Sync for PositionalEncoding
impl Unpin for PositionalEncoding
impl UnsafeUnpin for PositionalEncoding
impl UnwindSafe for PositionalEncoding
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.