Function split_heads 

pub fn split_heads(
    x: &Tensor,
    n_heads: usize,
    head_dim: usize,
) -> Result<Tensor>

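Reshape [1, seq, n_heads * head_dim] → [1, n_heads, seq, head_dim]. The last dimension is split into n_heads heads of head_dim features each, and the head axis is then moved in front of the sequence axis.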
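The shape change above amounts to two steps. First, view the input as [1, seq, n_heads, head_dim]. Second, swap axes 1 and 2.

The sketch below is a minimal illustration of that index permutation on a flat, row-major Vec<f32> with batch size 1. It does not use the crate's Tensor type. The function name split_heads_flat and its String error type are hypothetical, chosen only to make the data movement explicit.

```rust
/// Hypothetical plain-Vec sketch of the split_heads shape transform (batch size 1).
/// Input:  row-major [1, seq, n_heads * head_dim]
/// Output: row-major [1, n_heads, seq, head_dim]
fn split_heads_flat(
    x: &[f32],
    seq: usize,
    n_heads: usize,
    head_dim: usize,
) -> Result<Vec<f32>, String> {
    let hidden = n_heads * head_dim;
    // Assumed check: the input must hold exactly seq * n_heads * head_dim elements.
    if x.len() != seq * hidden {
        return Err(format!(
            "expected {} elements for [1, {}, {}], got {}",
            seq * hidden,
            seq,
            hidden,
            x.len()
        ));
    }
    let mut out = vec![0.0f32; x.len()];
    for s in 0..seq {
        for h in 0..n_heads {
            for d in 0..head_dim {
                // Source index, viewing [seq, hidden] as [seq, n_heads, head_dim].
                let src = s * hidden + h * head_dim + d;
                // Destination index in [n_heads, seq, head_dim] (seq and head axes swapped).
                let dst = h * seq * head_dim + s * head_dim + d;
                out[dst] = x[src];
            }
        }
    }
    Ok(out)
}

fn main() {
    // seq = 2, n_heads = 2, head_dim = 3: each token's 6 features split into 2 heads of 3.
    let x: Vec<f32> = (0..12).map(|v| v as f32).collect();
    let y = split_heads_flat(&x, 2, 2, 3).unwrap();
    // prints [0.0, 1.0, 2.0, 6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 9.0, 10.0, 11.0]
    println!("{:?}", y);
}
```

In the output, head 0 collects the first head_dim features of every token, and head 1 collects the next head_dim features. This is why the result is usually called "per-head" layout. With n_heads = 1 the permutation is the identity.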