pub fn shard_tensor(
tensor: &Tensor,
num_shards: usize,
) -> Result<Vec<Tensor>, ModelError>
Shard a tensor along its first dimension into num_shards roughly equal parts.
This is the core primitive for Fully Sharded Data Parallel (FSDP)-style
training: large parameter tensors are split across workers so that each
worker stores only its shard. Before a forward or backward pass the
shards are gathered (see gather_shards), and after the pass only the
local shard’s gradients are kept.
Each shard is a separate Tensor that can be placed on a different device.
If num_shards does not evenly divide the first dimension, earlier shards
receive one extra row.
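The uneven-split rule can be sketched with the size arithmetic alone; shard_sizes below is a hypothetical helper written for illustration, not part of this API, and no Tensor is needed to show the split:

```rust
// Compute per-shard row counts for a first dimension of `first_dim`
// split into `num_shards` parts. When the division is uneven, the
// earlier shards each receive one extra row, as described above.
fn shard_sizes(first_dim: usize, num_shards: usize) -> Vec<usize> {
    let base = first_dim / num_shards;
    let extra = first_dim % num_shards; // how many shards get +1 row
    (0..num_shards)
        .map(|i| if i < extra { base + 1 } else { base })
        .collect()
}

fn main() {
    // 10 rows across 3 shards: the first shard takes the extra row.
    assert_eq!(shard_sizes(10, 3), vec![4, 3, 3]);
    // Even division: all shards are equal.
    assert_eq!(shard_sizes(8, 4), vec![2, 2, 2, 2]);
}
```

The row counts always sum to the original first dimension, so gathering the shards back (as gather_shards does) reconstructs the full tensor exactly.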