Function masked_attention

pub fn masked_attention<F>(
    query: &ArrayView3<'_, F>,
    key: &ArrayView3<'_, F>,
    value: &ArrayView3<'_, F>,
    mask: &AttentionMask,
    scale: F,
) -> LinalgResult<Array3<F>>
where F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + Debug,

Masked attention: computes attention with a custom mask applied to the attention weights.

§Arguments

  • query - Query tensor of shape [batch_size, seq_len_q, d_model]
  • key - Key tensor of shape [batch_size, seq_len_k, d_model]
  • value - Value tensor of shape [batch_size, seq_len_k, d_model]
  • mask - The mask to apply to attention weights
  • scale - Scaling factor applied to the query-key dot products (commonly 1/sqrt(d_model))

§Returns

  • Output tensor of shape [batch_size, seq_len_q, d_model]
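
§Example (illustrative sketch)

The following is a minimal sketch of the computation that masked attention conventionally performs, written in plain Rust for a single batch element. It is not this crate's implementation. The name `masked_attention_sketch` is hypothetical, and a plain boolean matrix (`allowed[i][j]` is true when query position i may attend to key position j) stands in for `AttentionMask`, whose variants are not shown on this page. Masking with negative infinity before the softmax, and returning zeros for a fully masked row, are assumptions of the sketch.

```rust
/// Scaled dot-product attention with a boolean mask, for one batch element.
/// `query` is [seq_len_q][d_model]; `key` and `value` are [seq_len_k][d_model].
/// Hypothetical helper for illustration only, not part of the crate's API.
fn masked_attention_sketch(
    query: &[Vec<f64>],
    key: &[Vec<f64>],
    value: &[Vec<f64>],
    allowed: &[Vec<bool>],
    scale: f64,
) -> Vec<Vec<f64>> {
    let d_model = value.first().map_or(0, |row| row.len());
    let mut output = Vec::with_capacity(query.len());

    for (i, q) in query.iter().enumerate() {
        // Scaled dot-product scores. Masked positions are set to -inf so that
        // the softmax below assigns them zero weight.
        let scores: Vec<f64> = key
            .iter()
            .enumerate()
            .map(|(j, k)| {
                if allowed[i][j] {
                    scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>()
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();

        // Numerically stable softmax over the key axis.
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let mut row = vec![0.0; d_model];
        if max.is_finite() {
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            // Weighted sum of the value rows.
            for (w, v) in exps.iter().zip(value) {
                for (out, x) in row.iter_mut().zip(v) {
                    *out += (w / total) * x;
                }
            }
        }
        // Assumption: a fully masked query row is left as zeros.
        output.push(row);
    }
    output
}
```

With a causal mask such as `[[true, false], [true, true]]`, the first query position can only attend to the first key position, so its output row equals the first row of `value`; the second position mixes both value rows according to the softmax weights.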