pub fn attention<F>(
query: &ArrayView3<'_, F>,
key: &ArrayView3<'_, F>,
value: &ArrayView3<'_, F>,
mask: Option<&AttentionMask>,
scale: F,
) -> LinalgResult<Array3<F>>
Basic attention function, the building block for all other attention variants.
This implements the standard scaled dot-product attention mechanism: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V, where the 1/sqrt(d_k) factor is supplied through the `scale` argument.
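For concreteness, here is a minimal sketch of that formula for a single batch element, written against `ndarray`. It illustrates the math only and is not this crate's implementation (which additionally handles batching and masking); the function name `attention_2d` is hypothetical:

```rust
use ndarray::Array2;

/// Sketch of scaled dot-product attention for one batch element.
/// `q`: (seq_len_q, d_model); `k` and `v`: (seq_len_k, d_model).
fn attention_2d(q: &Array2<f64>, k: &Array2<f64>, v: &Array2<f64>, scale: f64) -> Array2<f64> {
    // Scores: Q K^T * scale, shape (seq_len_q, seq_len_k).
    let mut weights = q.dot(&k.t()) * scale;

    // Row-wise softmax, subtracting the row max for numerical stability.
    for mut row in weights.rows_mut() {
        let max = row.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        row.mapv_inplace(|x| (x - max).exp());
        let sum = row.sum();
        row.mapv_inplace(|x| x / sum);
    }

    // Output: weights * V, shape (seq_len_q, d_model).
    weights.dot(v)
}
```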
§Arguments
- `query` - Query tensor of shape [batch_size, seq_len_q, d_model]
- `key` - Key tensor of shape [batch_size, seq_len_k, d_model]
- `value` - Value tensor of shape [batch_size, seq_len_k, d_model]
- `mask` - Optional mask applied to the attention weights
- `scale` - Scaling factor for the dot product (typically 1/sqrt(d_k))
§Returns
- Output tensor of shape [batch_size, seq_len_q, d_model]
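A hedged usage sketch follows. The module path `scirs2_linalg::attention` and the choice of `f64` for the element type `F` are assumptions inferred from the signature above, so verify them against the crate's actual re-exports:

```rust
use ndarray::Array3;
use scirs2_linalg::attention::attention; // module path is an assumption

fn main() {
    let (batch, seq_len, d_model) = (2, 4, 8);

    // Toy inputs; in practice Q, K, and V come from learned projections.
    let q = Array3::<f64>::from_elem((batch, seq_len, d_model), 0.1);
    let k = Array3::<f64>::from_elem((batch, seq_len, d_model), 0.2);
    let v = Array3::<f64>::from_elem((batch, seq_len, d_model), 0.3);

    // Standard scaling factor: 1 / sqrt(d_k).
    let scale = 1.0 / (d_model as f64).sqrt();

    // No mask: every query position attends to every key position.
    let out = attention(&q.view(), &k.view(), &v.view(), None, scale)
        .expect("attention failed");
    assert_eq!(out.shape(), &[batch, seq_len, d_model]);
}
```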