pub fn flash_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor, KernelError>
Memory-efficient (flash) attention. Produces the same result as scaled_dot_product_attention, but processes the score matrix in tiles of Br×Bc rows and columns, so peak attention memory is O(Br×Bc) instead of the O(seq_q×seq_k) needed to materialize the full score matrix.
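A minimal usage sketch, assuming a hypothetical Tensor::randn constructor and the common [batch, heads, seq_len, head_dim] layout; the actual constructors, layout, and error type depend on the surrounding crate:

// Hypothetical setup: random Q/K/V tensors of shape [batch, heads, seq, head_dim].
let query = Tensor::randn(&[1, 8, 1024, 64])?; // [batch, heads, seq_q, head_dim]
let key   = Tensor::randn(&[1, 8, 1024, 64])?; // [batch, heads, seq_k, head_dim]
let value = Tensor::randn(&[1, 8, 1024, 64])?; // [batch, heads, seq_k, head_dim]

// mask = None: dense (unmasked) attention over the full sequence.
let out = flash_attention(&query, &key, &value, None)?;
// out has the same shape as query: [1, 8, 1024, 64].

Passing Some(&mask) restricts which key positions each query may attend to (e.g. a causal mask); the expected mask shape and semantics follow the crate's scaled_dot_product_attention.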