pub fn masked_attention<F>(
    query: &ArrayView3<'_, F>,
    key: &ArrayView3<'_, F>,
    value: &ArrayView3<'_, F>,
    mask: &AttentionMask,
    scale: F,
) -> LinalgResult<Array3<F>>
Masked attention: applies a custom mask to the attention weights.
§Arguments
- query: Query tensor of shape [batch_size, seq_len_q, d_model]
- key: Key tensor of shape [batch_size, seq_len_k, d_model]
- value: Value tensor of shape [batch_size, seq_len_k, d_model]
- mask: The mask to apply to attention weights
- scale: Scaling factor for dot product
§Returns
- Output tensor of shape [batch_size, seq_len_q, d_model]
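
The computation described above can be illustrated with a minimal sketch. This is not the library's implementation: it assumes a plain boolean mask (the variants of `AttentionMask` are not shown on this page), handles a single batch element, and uses nested `Vec<f64>` instead of `ndarray` views so that it needs no external crates. The function name `masked_attention_sketch` is hypothetical.

```rust
/// Hypothetical helper, not part of the library: masked scaled dot-product
/// attention for one batch element.
/// Assumption: `mask[i][j] == true` means query position `i` may attend to
/// key position `j`.
pub fn masked_attention_sketch(
    query: &[Vec<f64>], // [seq_len_q][d_model]
    key: &[Vec<f64>],   // [seq_len_k][d_model]
    value: &[Vec<f64>], // [seq_len_k][d_model]
    mask: &[Vec<bool>], // [seq_len_q][seq_len_k]
    scale: f64,
) -> Vec<Vec<f64>> {
    let d_model = value.first().map_or(0, |row| row.len());
    let mut output = Vec::with_capacity(query.len());

    for (i, q) in query.iter().enumerate() {
        // Scaled dot-product scores. Masked positions are set to -inf so
        // that the softmax below gives them zero weight.
        let scores: Vec<f64> = key
            .iter()
            .enumerate()
            .map(|(j, k)| {
                if mask[i][j] {
                    scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>()
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();

        // Numerically stable softmax over the key axis, followed by the
        // weighted sum of the value rows.
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let mut out_row = vec![0.0; d_model];
        if max.is_finite() {
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f64 = exps.iter().sum();
            for (w, v) in exps.iter().zip(value) {
                for (o, x) in out_row.iter_mut().zip(v) {
                    *o += (w / sum) * x;
                }
            }
        }
        // Assumption: a fully masked query row yields an all-zero output row.
        output.push(out_row);
    }
    output
}
```

With a causal (lower-triangular) mask, the first query position can only attend to the first key, so its output row equals the first value row. A common choice for `scale` is `1.0 / (d_model as f64).sqrt()`.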