Struct MultiHeadAttention

pub struct MultiHeadAttention<F: Float + Debug> { /* private fields */ }

Multi-head attention layer, as used in transformer architectures.

This layer performs the attention operation described in “Attention Is All You Need” by Vaswani et al. It projects the queries, keys, and values into multiple heads, computes scaled dot-product attention for each head, concatenates the per-head outputs, and applies a final linear projection that maps the result back to the original embedding dimension (d_model).
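Per head, the scaled dot-product step computes softmax(Q·Kᵀ / √head_dim)·V. The following dependency-free sketch shows that formula for a single head on plain nested `Vec`s. The function name `scaled_dot_product_attention` is hypothetical, and the sketch illustrates the standard formula only; it is not this crate's internal code and it omits batching, the input/output projections, and dropout.

```rust
/// Hypothetical single-head scaled dot-product attention:
/// out = softmax(q * k^T * scale) * v, with scale = 1 / sqrt(head_dim).
/// Shapes: q is (seq_q x d), k is (seq_k x d), v is (seq_k x d_v); inputs assumed non-empty.
fn scaled_dot_product_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let head_dim = q[0].len() as f64;
    let scale = 1.0 / head_dim.sqrt();
    let d_v = v[0].len();

    q.iter()
        .map(|qi| {
            // Scaled dot product of this query with every key.
            let scores: Vec<f64> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale)
                .collect();
            // Numerically stable softmax: subtract the row maximum first.
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f64 = exps.iter().sum();
            // Attention-weighted sum of the value rows.
            let mut out = vec![0.0; d_v];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += (w / sum) * x;
                }
            }
            out
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = q.clone();
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    println!("{:?}", scaled_dot_product_attention(&q, &k, &v));
}
```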

Examples

```rust
use scirs2_neural::layers::{MultiHeadAttention, Layer};
use scirs2_neural::layers::AttentionConfig;
use ndarray::Array3;
use rand::rngs::SmallRng;
use rand::SeedableRng;

// Create multi-head attention with 2 heads and 64-dim embeddings
// (num_heads * head_dim = 2 * 32 = 64 = d_model)
let mut rng = SmallRng::seed_from_u64(42);
let config = AttentionConfig {
    num_heads: 2,
    head_dim: 32,
    dropout_prob: 0.0,
    causal: false,
    scale: None,
};

let mha = MultiHeadAttention::new(64, config, &mut rng).unwrap();

// Forward pass with a batch of 2 samples, sequence length 3
let batch_size = 2;
let seq_len = 3;
let d_model = 64;
let input = Array3::<f64>::from_elem((batch_size, seq_len, d_model), 0.1).into_dyn();
let output = mha.forward(&input).unwrap();

// Output shape should match input shape: (batch_size, seq_len, d_model)
assert_eq!(output.shape(), input.shape());
```
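The example sets `causal: false`. In a causal (autoregressive) configuration, position i is conventionally allowed to attend only to positions j ≤ i. Below is a minimal sketch of that masking step, assuming the common approach of setting future scores to negative infinity before the softmax; `apply_causal_mask` and `softmax_rows` are hypothetical helpers, and how this crate implements `causal` internally is not documented here.

```rust
/// Hypothetical helper: mask out future positions in a (seq x seq)
/// score matrix by setting every entry with j > i to negative infinity.
fn apply_causal_mask(scores: &mut [Vec<f64>]) {
    for (i, row) in scores.iter_mut().enumerate() {
        for (j, s) in row.iter_mut().enumerate() {
            if j > i {
                *s = f64::NEG_INFINITY;
            }
        }
    }
}

/// Row-wise softmax; masked (-inf) entries receive exactly zero weight.
fn softmax_rows(scores: &[Vec<f64>]) -> Vec<Vec<f64>> {
    scores
        .iter()
        .map(|row| {
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = row.iter().map(|s| (s - max).exp()).collect();
            let sum: f64 = exps.iter().sum();
            exps.iter().map(|e| e / sum).collect()
        })
        .collect()
}

fn main() {
    // Equal raw scores, so each row spreads evenly over its allowed positions.
    let mut scores = vec![vec![0.0; 3]; 3];
    apply_causal_mask(&mut scores);
    let weights = softmax_rows(&scores);
    // Row 0 attends only to itself; row 2 spreads evenly over 3 positions.
    println!("{:?}", weights);
}
```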

Implementations

impl<F: Float + Debug + ScalarOperand + 'static> MultiHeadAttention<F>

pub fn new<R: Rng>(d_model: usize, config: AttentionConfig, rng: &mut R) -> Result<Self>

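Create a new multi-head attention layer.

Arguments
  • d_model - Embedding dimension, i.e. the size of the last axis of the input
  • config - Attention configuration (num_heads, head_dim, dropout_prob, causal, scale)
  • rng - Random number generator used for weight initialization
Returns
  • A Result containing the new multi-head attention layer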
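In the example above, `num_heads * head_dim` equals `d_model` (2 * 32 = 64) and `scale` is `None`. The sketch below is a caller-side sanity check under two assumptions that this documentation does not state: that the layer expects `num_heads * head_dim == d_model`, and that `scale: None` falls back to the conventional 1/√head_dim. `Config` and `check_config` are hypothetical stand-ins, not the crate's `AttentionConfig`.

```rust
/// Hypothetical stand-in for a subset of AttentionConfig's fields.
struct Config {
    num_heads: usize,
    head_dim: usize,
    scale: Option<f64>,
}

/// Check that the heads tile the embedding dimension exactly (an assumed
/// requirement) and resolve the effective attention scale.
fn check_config(d_model: usize, cfg: &Config) -> Result<f64, String> {
    if cfg.num_heads == 0 || cfg.num_heads * cfg.head_dim != d_model {
        return Err(format!(
            "num_heads * head_dim = {} * {} does not equal d_model = {}",
            cfg.num_heads, cfg.head_dim, d_model
        ));
    }
    // Assumed default when no explicit scale is given: 1 / sqrt(head_dim).
    Ok(cfg.scale.unwrap_or(1.0 / (cfg.head_dim as f64).sqrt()))
}

fn main() {
    let cfg = Config { num_heads: 2, head_dim: 32, scale: None };
    match check_config(64, &cfg) {
        Ok(scale) => println!("ok, scale = {scale}"),
        Err(e) => println!("invalid: {e}"),
    }
}
```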

Trait Implementations

impl<F: Float + Debug + ScalarOperand + 'static> Clone for MultiHeadAttention<F>

fn clone(&self) -> Self

Returns a duplicate of the value.

const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Stable since Rust 1.0.0.

impl<F: Debug + Float + Debug> Debug for MultiHeadAttention<F>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for MultiHeadAttention<F>

fn as_any(&self) -> &dyn Any

Get the layer as a dyn Any for downcasting

fn as_any_mut(&mut self) -> &mut dyn Any

Get the layer as a mutable dyn Any for downcasting
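`as_any` and `as_any_mut` let code that holds a layer as a trait object recover its concrete type. The sketch reproduces the pattern with `std::any::Any` and hypothetical `DemoLayer`, `Attention`, and `Dense` types so that it runs without the crate. With the real trait, the equivalent call would presumably be `layer.as_any().downcast_ref::<MultiHeadAttention<f64>>()`, which returns `None` when the layer is some other type.

```rust
use std::any::Any;

/// Hypothetical trait mirroring the as_any / as_any_mut pair on Layer.
trait DemoLayer {
    fn as_any(&self) -> &dyn Any;
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

/// Hypothetical concrete layer types.
struct Attention {
    num_heads: usize,
}

struct Dense {
    units: usize,
}

impl DemoLayer for Attention {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

impl DemoLayer for Dense {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

fn main() {
    let mut layers: Vec<Box<dyn DemoLayer>> = vec![
        Box::new(Dense { units: 8 }),
        Box::new(Attention { num_heads: 2 }),
    ];

    // Mutate only the attention layers, through their concrete type.
    for layer in layers.iter_mut() {
        if let Some(attn) = layer.as_any_mut().downcast_mut::<Attention>() {
            attn.num_heads *= 2;
        }
    }

    // Read typed fields back; a downcast to the wrong type yields None.
    for layer in &layers {
        if let Some(attn) = layer.as_any().downcast_ref::<Attention>() {
            println!("attention layer with {} heads", attn.num_heads);
        } else if let Some(dense) = layer.as_any().downcast_ref::<Dense>() {
            println!("dense layer with {} units", dense.units);
        }
    }
}
```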

fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Forward pass of the layer
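For the `(batch_size, seq_len, d_model)` input used in the example, multi-head attention implementations commonly rearrange each projected tensor into `(batch_size, num_heads, seq_len, head_dim)` so each head can be processed independently, then undo the rearrangement before the output projection. The sketch shows that index bookkeeping on flat row-major buffers; `split_heads` and `merge_heads` are hypothetical names, and this layout is an assumption about typical implementations, not a description of this crate's `forward`.

```rust
/// Hypothetical: reorder a row-major (batch, seq, num_heads * head_dim)
/// buffer into (batch, num_heads, seq, head_dim).
fn split_heads(x: &[f64], batch: usize, seq: usize, heads: usize, head_dim: usize) -> Vec<f64> {
    let d_model = heads * head_dim;
    assert_eq!(x.len(), batch * seq * d_model);
    let mut out = vec![0.0; x.len()];
    for b in 0..batch {
        for s in 0..seq {
            for h in 0..heads {
                for d in 0..head_dim {
                    let src = (b * seq + s) * d_model + h * head_dim + d;
                    let dst = ((b * heads + h) * seq + s) * head_dim + d;
                    out[dst] = x[src];
                }
            }
        }
    }
    out
}

/// Inverse of split_heads: (batch, num_heads, seq, head_dim) -> (batch, seq, d_model).
fn merge_heads(x: &[f64], batch: usize, seq: usize, heads: usize, head_dim: usize) -> Vec<f64> {
    let d_model = heads * head_dim;
    let mut out = vec![0.0; x.len()];
    for b in 0..batch {
        for s in 0..seq {
            for h in 0..heads {
                for d in 0..head_dim {
                    let src = ((b * heads + h) * seq + s) * head_dim + d;
                    let dst = (b * seq + s) * d_model + h * head_dim + d;
                    out[dst] = x[src];
                }
            }
        }
    }
    out
}

fn main() {
    // batch = 1, seq = 2, heads = 2, head_dim = 2 (so d_model = 4)
    let x: Vec<f64> = (0..8).map(|i| i as f64).collect();
    let split = split_heads(&x, 1, 2, 2, 2);
    println!("{:?}", split); // [0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0]
    assert_eq!(merge_heads(&split, 1, 2, 2, 2), x);
}
```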

fn backward(&self, input: &Array<F, IxDyn>, _grad_output: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Backward pass of the layer to compute gradients

fn update(&mut self, learning_rate: F) -> Result<()>

Update the layer parameters using the stored gradients and the given learning rate
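`update` takes only a learning rate, which suggests it applies gradients already stored on the layer (the trait also exposes `gradients` and `set_gradients`). The simplest rule with that shape is plain gradient descent, θ ← θ − η·∇θ. The sketch applies it to flat parameter vectors; `sgd_update` is hypothetical, and whether `MultiHeadAttention` uses exactly this rule is not stated in this documentation.

```rust
/// Hypothetical plain gradient-descent step: param -= learning_rate * grad,
/// applied element-wise to each (parameter, gradient) pair.
fn sgd_update(params: &mut [Vec<f64>], grads: &[Vec<f64>], learning_rate: f64) -> Result<(), String> {
    if params.len() != grads.len() {
        return Err("parameter/gradient count mismatch".to_string());
    }
    for (p, g) in params.iter_mut().zip(grads) {
        if p.len() != g.len() {
            return Err("parameter/gradient shape mismatch".to_string());
        }
        for (pi, gi) in p.iter_mut().zip(g) {
            *pi -= learning_rate * gi;
        }
    }
    Ok(())
}

fn main() {
    let mut params = vec![vec![1.0, 2.0], vec![0.5]];
    let grads = vec![vec![0.1, -0.2], vec![1.0]];
    sgd_update(&mut params, &grads, 0.5).unwrap();
    println!("{:?}", params);
}
```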

fn params(&self) -> Vec<Array<F, IxDyn>>

Get the parameters of the layer

fn gradients(&self) -> Vec<Array<F, IxDyn>>

Get the gradients of the layer parameters

fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()>

Set the gradients of the layer parameters

fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()>

Set the parameters of the layer

fn set_training(&mut self, _training: bool)

Set the layer to training mode (true) or evaluation mode (false)

fn is_training(&self) -> bool

Get the current training mode
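`AttentionConfig` carries a `dropout_prob`, and `set_training` / `is_training` switch between training and evaluation mode. A widespread convention, assumed here and not confirmed for this crate, is inverted dropout: in training mode each element is zeroed with probability p and the survivors are scaled by 1/(1 − p), while in evaluation mode the input passes through unchanged. The sketch uses a tiny xorshift generator so it needs no external crates; `dropout` is a hypothetical helper.

```rust
/// Hypothetical inverted dropout. In evaluation mode (or with p == 0.0)
/// the input is returned unchanged. `seed` must be nonzero (xorshift state).
fn dropout(x: &[f64], p: f64, training: bool, seed: &mut u64) -> Vec<f64> {
    if !training || p == 0.0 {
        return x.to_vec();
    }
    let keep = 1.0 - p;
    x.iter()
        .map(|&v| {
            // xorshift64 step: a tiny PRNG, adequate for illustration only.
            *seed ^= *seed << 13;
            *seed ^= *seed >> 7;
            *seed ^= *seed << 17;
            // Top 53 bits mapped to a uniform sample in [0, 1).
            let u = (*seed >> 11) as f64 / (1u64 << 53) as f64;
            if u < p { 0.0 } else { v / keep }
        })
        .collect()
}

fn main() {
    let x = vec![1.0_f64; 8];
    let mut seed: u64 = 0x9E37_79B9_7F4A_7C15; // any nonzero seed
    let eval_out = dropout(&x, 0.5, false, &mut seed);
    let train_out = dropout(&x, 0.5, true, &mut seed);
    println!("eval:  {:?}", eval_out);
    println!("train: {:?}", train_out);
}
```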

fn layer_type(&self) -> &str

Get the type of the layer (e.g., “Dense”, “Conv2D”)

fn parameter_count(&self) -> usize

Get the number of trainable parameters in this layer
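For a standard multi-head attention block with separate query, key, and value projections (each `d_model → num_heads * head_dim`) and an output projection back to `d_model`, the trainable parameter count is a simple product. The sketch assumes exactly these four weight matrices, with biases optional; the layer's internals are not documented here, so `parameter_count` on the real type may report a different total. `mha_param_count` is hypothetical.

```rust
/// Hypothetical parameter count for standard multi-head attention:
/// W_q, W_k, W_v are d_model x (num_heads * head_dim);
/// W_o is (num_heads * head_dim) x d_model; biases are optional.
fn mha_param_count(d_model: usize, num_heads: usize, head_dim: usize, with_bias: bool) -> usize {
    let inner = num_heads * head_dim;
    let weights = 3 * d_model * inner + inner * d_model;
    let biases = if with_bias { 3 * inner + d_model } else { 0 };
    weights + biases
}

fn main() {
    // Configuration from the example: d_model = 64, 2 heads of size 32.
    println!("{}", mha_param_count(64, 2, 32, false)); // 16384
    println!("{}", mha_param_count(64, 2, 32, true)); // 16640
}
```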

fn layer_description(&self) -> String

Get a detailed description of this layer

impl<F: Float + Debug + ScalarOperand + 'static> ParamLayer<F> for MultiHeadAttention<F>

fn get_parameters(&self) -> Vec<&Array<F, IxDyn>>

Get the parameters of the layer as a vector of arrays

fn get_gradients(&self) -> Vec<&Array<F, IxDyn>>

Get the gradients of the parameters

fn set_parameters(&mut self, params: Vec<Array<F, IxDyn>>) -> Result<()>

Set the parameters of the layer

Auto Trait Implementations

impl<F> !Freeze for MultiHeadAttention<F>

impl<F> !RefUnwindSafe for MultiHeadAttention<F>

impl<F> Send for MultiHeadAttention<F>
where F: Send,

impl<F> !Sync for MultiHeadAttention<F>

impl<F> Unpin for MultiHeadAttention<F>
where F: Unpin,

impl<F> UnwindSafe for MultiHeadAttention<F>
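The list above shows `MultiHeadAttention<F>` is `Send` (when `F: Send`) but not `Sync`, and also `!RefUnwindSafe` and `!Freeze`, a combination typical of types that use interior mutability such as `RefCell`. In practice a layer can be moved into another thread, but a shared `&MultiHeadAttention` cannot be used from several threads at once; giving each thread its own clone or wrapping the layer in a `Mutex` works. The sketch demonstrates both with a hypothetical `FakeLayer` that is `Send + !Sync` because it holds a `RefCell`; whether the real type uses `RefCell` is an assumption.

```rust
use std::cell::RefCell;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in: interior mutability via RefCell makes the type
/// Send (its contents are Send) but not Sync.
#[derive(Clone)]
struct FakeLayer {
    grads: RefCell<Vec<f64>>,
}

impl FakeLayer {
    fn accumulate(&self, g: f64) {
        self.grads.borrow_mut().push(g);
    }
}

fn main() {
    let layer = FakeLayer { grads: RefCell::new(Vec::new()) };

    // 1) Send: move an owned clone into a worker thread.
    let owned = layer.clone();
    let handle = thread::spawn(move || {
        owned.accumulate(1.0);
        owned.grads.into_inner().len()
    });
    println!("worker saw {} gradient(s)", handle.join().unwrap());

    // 2) !Sync: sharing across threads needs a Mutex; Arc<FakeLayer> alone would not compile.
    let shared = Arc::new(Mutex::new(layer));
    let workers: Vec<_> = (0..4)
        .map(|i| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || shared.lock().unwrap().accumulate(i as f64))
        })
        .collect();
    for w in workers {
        w.join().unwrap();
    }
    println!("shared layer has {} gradient(s)", shared.lock().unwrap().grads.borrow().len());
}
```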

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V