Struct PositionalEncoding 

pub struct PositionalEncoding { /* private fields */ }

Sinusoidal positional encoding module for transformers.

This module adds fixed (non-learnable) positional encodings to input embeddings. It follows the sinusoidal formulation of the original “Attention Is All You Need” paper, in which each position is encoded by sine and cosine functions of different frequencies.

It supports both f32 and f64 types and applies dropout after adding the encodings.
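For reference, the paper defines PE(pos, 2k) = sin(pos / 10000^(2k / d_model)) and PE(pos, 2k+1) = cos(pos / 10000^(2k / d_model)). The sketch below computes that table in plain Rust. It is an illustrative reconstruction of the paper's formula, not zyx's internal code. The function name sinusoidal_table and the Vec<Vec<f64>> layout are assumptions.

```rust
/// Builds a max_len x d_model table of sinusoidal positional encodings using
/// the formulation from "Attention Is All You Need":
///   PE(pos, 2k)     = sin(pos / 10000^(2k / d_model))
///   PE(pos, 2k + 1) = cos(pos / 10000^(2k / d_model))
/// Illustrative sketch only (hypothetical helper, not zyx's implementation).
fn sinusoidal_table(max_len: usize, d_model: usize) -> Vec<Vec<f64>> {
    let mut table = vec![vec![0.0_f64; d_model]; max_len];
    for (pos, row) in table.iter_mut().enumerate() {
        // Step over even column indices; `i` plays the role of 2k above.
        for i in (0..d_model).step_by(2) {
            let angle = pos as f64 / 10000_f64.powf(i as f64 / d_model as f64);
            row[i] = angle.sin();
            // With an odd d_model the last even column has no cosine partner.
            if i + 1 < d_model {
                row[i + 1] = angle.cos();
            }
        }
    }
    table
}
```

Row pos of the table is the vector added to the embedding at sequence position pos.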

Implementations

impl PositionalEncoding

pub fn new(d_model: u64, max_len: usize, dropout_prob: f32, dtype: DType) -> Result<Self, ZyxError>

Creates a new PositionalEncoding module.

Arguments
  • d_model - The embedding dimension (must match the input’s last dimension).
  • max_len - Maximum sequence length this module will support.
  • dropout_prob - Dropout probability applied after adding the positional encoding.
  • dtype - Data type of the encoding (must be DType::F32 or DType::F64).
Errors

Returns a ZyxError::ShapeError if dtype is not a floating-point type (DType::F32 or DType::F64).

Example
let pe = PositionalEncoding::new(512, 1024, 0.1, DType::F32)?;
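A minimal sketch of the dtype check documented under Errors, assuming simplified stand-in types. DTypeSketch, SketchError, PeConfig and new_config are hypothetical names, not part of zyx. Only the documented error case is modelled here.

```rust
/// Hypothetical stand-in for zyx's DType (not the real enum).
#[derive(Debug, Clone, Copy, PartialEq)]
enum DTypeSketch {
    F32,
    F64,
    I32,
}

/// Hypothetical error type standing in for ZyxError.
#[derive(Debug, PartialEq)]
enum SketchError {
    UnsupportedDType(DTypeSketch),
}

/// Hypothetical record of the constructor arguments.
#[derive(Debug)]
struct PeConfig {
    d_model: u64,
    max_len: usize,
    dropout_prob: f32,
    dtype: DTypeSketch,
}

/// Accepts only floating-point dtypes, mirroring the documented error case.
fn new_config(
    d_model: u64,
    max_len: usize,
    dropout_prob: f32,
    dtype: DTypeSketch,
) -> Result<PeConfig, SketchError> {
    match dtype {
        DTypeSketch::F32 | DTypeSketch::F64 => Ok(PeConfig { d_model, max_len, dropout_prob, dtype }),
        other => Err(SketchError::UnsupportedDType(other)),
    }
}
```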

pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError>

Applies positional encoding to the input tensor.

Arguments
  • x - A tensor of shape [batch_size, seq_len, d_model].
Returns

A new tensor with the same shape as the input, with positional encodings added and dropout applied.
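The documentation does not say which dropout variant is used. The sketch below assumes the common inverted-dropout scheme, in which each element is zeroed with probability p and the surviving elements are scaled by 1 / (1 - p) so that the expected value is unchanged. XorShift64 and dropout are hypothetical helpers. The generator keeps the example free of external crates.

```rust
/// Minimal xorshift64 generator (hypothetical helper); the seed must be nonzero.
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a pseudo-random value in [0, 1).
    fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        // Use the top 53 bits to build a uniform value in [0, 1).
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Inverted dropout (assumed variant): zero each element with probability `p`
/// and scale the survivors by 1 / (1 - p).
fn dropout(values: &[f64], p: f64, rng: &mut XorShift64) -> Vec<f64> {
    assert!((0.0..1.0).contains(&p), "dropout probability must be in [0, 1)");
    let scale = 1.0 / (1.0 - p);
    values
        .iter()
        .map(|&v| if rng.next_f64() < p { 0.0 } else { v * scale })
        .collect()
}
```

With p = 0 the input passes through unchanged. This matches the usual convention of disabling dropout at inference time.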

Errors

Returns a ZyxError::ShapeError if:

  • The input tensor is not 3-dimensional.
  • The input's last dimension does not equal d_model.
  • The sequence length exceeds the configured max_len.
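The three shape checks above and the addition broadcast over the batch dimension can be sketched on a flat row-major buffer. The helper add_positional_encoding is hypothetical, not zyx's API. It uses plain f64 slices in place of Tensor and omits dropout, which is equivalent to dropout_prob = 0.

```rust
/// Adds a precomputed [max_len][d_model] table to a row-major
/// [batch, seq_len, d_model] buffer, applying the documented shape checks.
/// Hypothetical helper; dropout is intentionally omitted.
fn add_positional_encoding(
    x: &[f64],
    shape: &[usize],
    table: &[Vec<f64>],
) -> Result<Vec<f64>, String> {
    // Check 1: the input must be 3-dimensional.
    let (batch, seq_len, d_model) = match shape {
        [b, s, d] => (*b, *s, *d),
        _ => return Err(format!("expected a 3-D input, got {} dims", shape.len())),
    };
    let max_len = table.len();
    let table_d = table.first().map_or(0, |row| row.len());
    // Check 2: the last dimension must match the encoding width.
    if d_model != table_d {
        return Err(format!("d_model mismatch: input {d_model}, encoding {table_d}"));
    }
    // Check 3: the sequence must fit in the precomputed table.
    if seq_len > max_len {
        return Err(format!("seq_len {seq_len} exceeds max_len {max_len}"));
    }
    if x.len() != batch * seq_len * d_model {
        return Err("data length does not match shape".to_string());
    }
    let mut out = x.to_vec();
    for (idx, value) in out.iter_mut().enumerate() {
        // Recover (position, feature); the batch index does not affect the encoding.
        let pos = (idx / d_model) % seq_len;
        let feat = idx % d_model;
        *value += table[pos][feat];
    }
    Ok(out)
}
```

Only the first seq_len rows of the table are used, so one precomputed table serves every sequence length up to max_len.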

Trait Implementations

impl Debug for PositionalEncoding

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<'a> IntoIterator for &'a PositionalEncoding

type Item = &'a Tensor

The type of the elements being iterated over.

type IntoIter = IntoIter<&'a Tensor>

Which kind of iterator are we turning this into?

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value.

impl<'a> IntoIterator for &'a mut PositionalEncoding

type Item = &'a mut Tensor

The type of the elements being iterated over.

type IntoIter = IntoIter<&'a mut Tensor>

Which kind of iterator are we turning this into?

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value.
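The two IntoIterator impls let a module be used directly in for loops over its tensors, either by shared or by mutable reference. Below is a minimal sketch of the same pattern. It assumes a placeholder Tensor type and a single hypothetical field named pe, because PositionalEncoding's actual fields are private.

```rust
/// Placeholder tensor type (hypothetical; not zyx's Tensor).
#[derive(Debug, PartialEq)]
struct Tensor(Vec<f32>);

/// Hypothetical module holding one precomputed encoding tensor.
struct PosEnc {
    pe: Tensor,
}

// Shared iteration: `for t in &module` yields `&Tensor`.
impl<'a> IntoIterator for &'a PosEnc {
    type Item = &'a Tensor;
    type IntoIter = std::vec::IntoIter<&'a Tensor>;

    fn into_iter(self) -> Self::IntoIter {
        vec![&self.pe].into_iter()
    }
}

// Mutable iteration: `for t in &mut module` yields `&mut Tensor`.
impl<'a> IntoIterator for &'a mut PosEnc {
    type Item = &'a mut Tensor;
    type IntoIter = std::vec::IntoIter<&'a mut Tensor>;

    fn into_iter(self) -> Self::IntoIter {
        vec![&mut self.pe].into_iter()
    }
}
```

Collecting the references into a Vec keeps the iterator type nameable (vec::IntoIter), which matches the IntoIter associated types listed above.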

impl Module for PositionalEncoding

fn iter<'a>(&'a self) -> impl Iterator<Item = &'a Tensor>

Iterate over all tensors immutably

fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Tensor>

Iterate over all tensors mutably

fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a Tensor)>

Iterate over tensors together with their names, as (name, tensor) pairs, without consuming the module.

fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut Tensor)>

Iterate mutably over tensors together with their names, as (name, tensor) pairs.

fn realize(&self) -> Result<(), ZyxError>

Realize all tensors in the module.

fn set_params(&mut self, params: &mut HashMap<String, Tensor>)

Set parameters from params: each entry whose name matches a module tensor is removed from params and assigned to that tensor. Module tensors with no matching entry in params are skipped.

fn save(&self, path: impl AsRef<Path>) -> Result<(), ZyxError>

Save tensors or modules to a file whose format is determined by the file extension. Currently, safetensors is the only supported format.
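The set_params contract can be sketched with a HashMap: matching entries are moved out of the map and missing ones are skipped. The Tensor placeholder, the PosEnc struct and the tensor name "pe" are all hypothetical. The real module's tensor names are not documented here.

```rust
use std::collections::HashMap;

/// Placeholder tensor type (hypothetical; not zyx's Tensor).
#[derive(Debug, PartialEq)]
struct Tensor(Vec<f32>);

/// Hypothetical module with a single named tensor.
struct PosEnc {
    pe: Tensor,
}

impl PosEnc {
    /// Named mutable access, shaped like the iter_tensors_mut signature above.
    fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut Tensor)> + 'a {
        std::iter::once(("pe".to_string(), &mut self.pe))
    }

    /// Moves each matching entry out of `params` into the module;
    /// tensors without an entry keep their current value.
    fn set_params(&mut self, params: &mut HashMap<String, Tensor>) {
        for (name, tensor) in self.iter_tensors_mut() {
            if let Some(new_value) = params.remove(&name) {
                *tensor = new_value;
            }
        }
    }
}
```

Because matched entries are removed, whatever is left in the map after the call is exactly the set of parameters the module did not use. This makes it easy to detect unexpected keys.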

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.