Function flash_attention

pub fn flash_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
) -> Result<Tensor, KernelError>

Memory-efficient (flash) attention. Produces the same result as scaled_dot_product_attention, but processes the score matrix in tiles of Br query rows by Bc key columns, so peak memory is O(Br×Bc) per tile rather than O(seq_q×seq_k) for the fully materialized seq_q×seq_k score matrix.
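A minimal sketch of the idea behind the memory bound, not this crate's implementation: the function name `flash_attention_row`, the plain `Vec<f32>` tensors, and the single-query-row simplification are all illustrative assumptions. It shows the online-softmax trick that lets attention stream over key/value blocks of size Bc while keeping only O(Bc) scores live, instead of the full score row.

```rust
// Illustrative sketch (assumed API, not this crate's kernel): attention for a
// single query row, streaming over keys/values in blocks of `bc` columns.
// Only one block of scores is materialized at a time, so peak extra memory
// is O(bc) rather than O(seq_k).
fn flash_attention_row(
    q: &[f32],
    keys: &[Vec<f32>],   // seq_k rows, each of dimension q.len()
    values: &[Vec<f32>], // seq_k rows; assumed non-empty
    bc: usize,
) -> Vec<f32> {
    let d = q.len();
    let scale = 1.0 / (d as f32).sqrt();
    let mut m = f32::NEG_INFINITY;          // running max of scores seen so far
    let mut l = 0.0f32;                     // running softmax denominator
    let mut acc = vec![0.0f32; values[0].len()]; // running weighted sum of values

    for (kb, vb) in keys.chunks(bc).zip(values.chunks(bc)) {
        // Scores for this block only: O(bc) live memory.
        let s: Vec<f32> = kb
            .iter()
            .map(|k| k.iter().zip(q).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();

        // New running max, then rescale previous denominator and accumulator
        // so all exponentials stay relative to the same maximum (numerically
        // stable online softmax).
        let m_new = s.iter().fold(m, |a, &b| a.max(b));
        let corr = (m - m_new).exp(); // exp(-inf) == 0.0 on the first block
        l *= corr;
        for x in acc.iter_mut() {
            *x *= corr;
        }

        for (si, v) in s.iter().zip(vb) {
            let p = (si - m_new).exp();
            l += p;
            for (a, vi) in acc.iter_mut().zip(v) {
                *a += p * vi;
            }
        }
        m = m_new;
    }

    // Final normalization: acc holds sum_j exp(s_j - m) * v_j, l holds the
    // softmax denominator, so acc / l is exactly softmax(s) @ V.
    acc.iter().map(|a| a / l).collect()
}
```

Because the running max and denominator are carried across blocks, the result is independent of the block size `bc`; the real kernel additionally tiles over Br query rows and fuses the loops on-device.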