pub struct MultiHeadAttention<F: Float + Debug> { /* private fields */ }
Multi-head attention layer as used in transformer architectures
This layer implements the attention operation described in “Attention Is All You Need” (Vaswani et al., 2017). It projects the queries, keys, and values into multiple heads, computes scaled dot-product attention in each head, concatenates the per-head results, and projects them back to the original dimension.
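The scaled dot-product step at the heart of each head can be sketched in plain Rust. This is an illustration only, using std-only `Vec<f64>` in place of this crate's array types; the function names here are illustrative and not part of the crate's API:

```rust
// Illustration only: scaled dot-product attention for a single head,
// computing softmax(q · kᵀ / sqrt(d_k)) · V for one query vector.

fn softmax(scores: &[f64]) -> Vec<f64> {
    // Subtract the max for numerical stability before exponentiating.
    let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Attention output for one query against a sequence of keys/values.
fn attention(query: &[f64], keys: &[Vec<f64>], values: &[Vec<f64>]) -> Vec<f64> {
    let d_k = query.len() as f64;
    // Scaled dot-product score against every key.
    let scores: Vec<f64> = keys.iter().map(|k| dot(query, k) / d_k.sqrt()).collect();
    let weights = softmax(&scores);
    // Weighted sum of the value vectors.
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w * x;
        }
    }
    out
}

fn main() {
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let values = vec![vec![10.0, 0.0], vec![0.0, 10.0]];
    // A query aligned with the first key attends mostly to the first value.
    let out = attention(&[1.0, 0.0], &keys, &values);
    assert!(out[0] > out[1]);
}
```

The multi-head layer runs this computation independently per head on head-sized projections of the input, then concatenates and re-projects the results.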
Examples
use scirs2_neural::layers::{MultiHeadAttention, Layer};
use scirs2_neural::layers::AttentionConfig;
use ndarray::Array3;
use rand::rngs::SmallRng;
use rand::SeedableRng;
// Create multi-head attention with 2 heads and 64-dim embeddings
let mut rng = SmallRng::seed_from_u64(42);
let config = AttentionConfig {
    num_heads: 2,
    head_dim: 32,
    dropout_prob: 0.0,
    causal: false,
    scale: None,
};
let mha = MultiHeadAttention::new(64, config, &mut rng).unwrap();
// Forward pass with a batch of 2 samples, sequence length 3
let batch_size = 2;
let seq_len = 3;
let d_model = 64;
let input = Array3::<f64>::from_elem((batch_size, seq_len, d_model), 0.1).into_dyn();
let output = mha.forward(&input).unwrap();
// Output shape should match input shape
assert_eq!(output.shape(), input.shape());
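The `causal` flag in `AttentionConfig` requests autoregressive masking: each position may attend only to itself and earlier positions. A hedged, std-only sketch of the usual way this is implemented (masking future scores to negative infinity before the softmax; this is the standard technique, not a view into the crate's internals):

```rust
// Illustration only: what `causal: true` implies for the attention weights.

fn softmax_row(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Mask the strict upper triangle of a seq_len x seq_len score matrix,
/// then softmax each row.
fn causal_attention_weights(mut scores: Vec<Vec<f64>>) -> Vec<Vec<f64>> {
    let n = scores.len();
    for i in 0..n {
        for j in (i + 1)..n {
            scores[i][j] = f64::NEG_INFINITY; // future positions get zero weight
        }
    }
    scores.iter().map(|row| softmax_row(row)).collect()
}

fn main() {
    // Uniform scores: each position splits attention over itself and the past.
    let weights = causal_attention_weights(vec![vec![0.0; 3]; 3]);
    assert_eq!(weights[0][1], 0.0); // position 0 cannot see position 1
    assert_eq!(weights[0][2], 0.0);
    assert!((weights[2].iter().sum::<f64>() - 1.0).abs() < 1e-12);
}
```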
Implementations
impl<F: Float + Debug + ScalarOperand + 'static> MultiHeadAttention<F>
Trait Implementations
impl<F: Float + Debug + ScalarOperand + 'static> Clone for MultiHeadAttention<F>
impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for MultiHeadAttention<F>
fn as_any_mut(&mut self) -> &mut dyn Any
Get the layer as a mutable dyn Any for downcasting.
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>
Forward pass of the layer.
fn backward(
    &self,
    input: &Array<F, IxDyn>,
    _grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>>
Backward pass of the layer to compute gradients.
fn update(&mut self, learning_rate: F) -> Result<()>
Update the layer parameters with the given gradients.
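Presumably this is a gradient-descent step of the form `param -= learning_rate * grad` applied to each parameter tensor. A minimal std-only sketch of that update rule, with flat `Vec<f64>` standing in for the crate's `Array<F, IxDyn>` parameters (a sketch of the general technique, not the crate's actual implementation):

```rust
// Illustration only: an in-place SGD step over one flattened parameter tensor.

fn sgd_step(params: &mut [f64], grads: &[f64], learning_rate: f64) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= learning_rate * g; // move against the gradient
    }
}

fn main() {
    let mut w = vec![1.0, 2.0];
    sgd_step(&mut w, &[0.5, -1.0], 0.1);
    assert!((w[0] - 0.95).abs() < 1e-9);
    assert!((w[1] - 2.1).abs() < 1e-9);
}
```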
fn gradients(&self) -> Vec<Array<F, IxDyn>>
Get the gradients of the layer parameters.
fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()>
Set the gradients of the layer parameters.
fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()>
Set the parameters of the layer.
fn set_training(&mut self, _training: bool)
Set the layer to training mode (true) or evaluation mode (false).
fn is_training(&self) -> bool
Get the current training mode.
fn layer_type(&self) -> &str
Get the type of the layer (e.g., “Dense”, “Conv2D”).
fn parameter_count(&self) -> usize
Get the number of trainable parameters in this layer.
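For orientation, multi-head attention with four `d_model -> d_model` linear projections (Q, K, V, and output) has a predictable count. Whether this crate's projections carry bias terms is an assumption here, so the sketch parameterizes it:

```rust
// Illustration only: parameter count under the assumption of four
// d_model x d_model projection matrices, optionally with bias vectors.
// This may not match what this crate's parameter_count() actually reports.

fn mha_param_count(d_model: usize, with_bias: bool) -> usize {
    let per_projection = d_model * d_model + if with_bias { d_model } else { 0 };
    4 * per_projection
}

fn main() {
    // For the 64-dim example above:
    assert_eq!(mha_param_count(64, false), 16_384); // 4 * 64 * 64
    assert_eq!(mha_param_count(64, true), 16_640);  // plus 4 * 64 bias terms
}
```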
fn layer_description(&self) -> String
Get a detailed description of this layer.
impl<F: Float + Debug + ScalarOperand + 'static> ParamLayer<F> for MultiHeadAttention<F>
Auto Trait Implementations
impl<F> !Freeze for MultiHeadAttention<F>
impl<F> !RefUnwindSafe for MultiHeadAttention<F>
impl<F> Send for MultiHeadAttention<F>
where
    F: Send,
impl<F> !Sync for MultiHeadAttention<F>
impl<F> Unpin for MultiHeadAttention<F>
where
    F: Unpin,
impl<F> UnwindSafe for MultiHeadAttention<F>
where
    F: UnwindSafe + RefUnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true, otherwise into a Right variant.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true, otherwise into a Right variant.