Function scaled_dot_product_attention

pub fn scaled_dot_product_attention<F>(
    query: &ArrayView3<'_, F>,
    key: &ArrayView3<'_, F>,
    value: &ArrayView3<'_, F>,
    mask: Option<&AttentionMask>,
    scale: F,
) -> LinalgResult<Array3<F>>
where F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + Debug + 'static,

Scaled Dot-Product Attention

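The standard attention mechanism used in Transformer models: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V, where d_k is the size of the last axis of the key tensor and the softmax is taken over the key positions.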

§Arguments

  • query - Query tensor of shape [batch_size, seq_len_q, d_model]
  • key - Key tensor of shape [batch_size, seq_len_k, d_model]
  • value - Value tensor of shape [batch_size, seq_len_k, d_model]
  • mask - Optional mask to apply to attention weights
  • scale - Scaling factor applied to the dot products; it is a required argument, and the conventional choice is 1/sqrt(d_k)

§Returns

  • Output tensor of shape [batch_size, seq_len_q, d_model]
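
To make the formula concrete, here is a minimal sketch of the computation for a single batch element, written in plain Rust with nested `Vec<f64>` instead of `ndarray` views. It is an illustration under stated assumptions, not this crate's implementation: the helper names `softmax` and `attention_single` are hypothetical, the mask is left out, and `k` and `v` are assumed to be non-empty with matching lengths.

```rust
/// Numerically stable softmax over one row of scores.
fn softmax(row: &[f64]) -> Vec<f64> {
    // Subtract the row maximum before exponentiating to avoid overflow.
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical helper: attention for one batch element, without a mask.
/// `q` has shape [seq_len_q][d_model]; `k` and `v` have shape [seq_len_k][d_model].
/// Returns a matrix of shape [seq_len_q][d_model].
fn attention_single(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    scale: f64,
) -> Vec<Vec<f64>> {
    // Assumption: `v` is non-empty, so the output width can be read from its first row.
    let d_model = v[0].len();
    q.iter()
        .map(|q_row| {
            // scores[j] = scale * dot(q_row, k[j])
            let scores: Vec<f64> = k
                .iter()
                .map(|k_row| scale * q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f64>())
                .collect();
            // Normalize over key positions.
            let weights = softmax(&scores);
            // Output row = weighted sum of the value rows.
            let mut out = vec![0.0; d_model];
            for (w, v_row) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(v_row) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}
```

Calling `attention_single(&q, &k, &v, 1.0 / (d_model as f64).sqrt())` mirrors the conventional choice of `scale`. The result has one row per query position, matching the [seq_len_q, d_model] slice of the output shape listed above.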