Trait PositionalEncoding

Source
pub trait PositionalEncoding<F: Float + Debug> {
    // Required methods
    fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
    fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>>;
    fn update(&mut self, learning_rate: F) -> Result<()>;

    // Provided methods
    fn params(&self) -> Vec<Array<F, IxDyn>>  { ... }
    fn set_training(&mut self, _training: bool) { ... }
    fn is_training(&self) -> bool { ... }
}
Expand description

Common interface for positional encoding implementations, which add sequence-position information to input embeddings.

Required Methods

Source

fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Apply the positional encoding to the input embeddings.

Arguments
  • `embeddings` - Input embeddings of shape [batch, seq_len, d_model]
Returns
  • The embeddings with the positional encoding added
Source

fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>>

Get the positional encoding matrix directly, without applying it to any embeddings.

Arguments
  • `seq_len` - Sequence length to generate encodings for
Returns
  • The positional encoding matrix, of shape [seq_len, d_model]
Source

fn update(&mut self, learning_rate: F) -> Result<()>

Update the learnable parameters, if any.

Arguments
  • `learning_rate` - Learning rate to use for the update

Provided Methods

Source

fn params(&self) -> Vec<Array<F, IxDyn>>

Get the learnable parameters, if any.

Returns
  • A vector containing each parameter as an array
Source

fn set_training(&mut self, _training: bool)

Set the training mode. The default implementation does nothing.

Source

fn is_training(&self) -> bool

Return whether training mode is enabled. The default implementation returns `false`.

Implementors