pub fn split_heads( x: &Tensor, n_heads: usize, head_dim: usize, ) -> Result<Tensor>
Reshapes `x` from `[1, seq, n_heads * head_dim]` to `[1, n_heads, seq, head_dim]`.
Input shape (`x`): `[1, seq, n_heads * head_dim]`
Output shape: `[1, n_heads, seq, head_dim]`