Function shard_tensor 

pub fn shard_tensor(
    tensor: &Tensor,
    num_shards: usize,
) -> Result<Vec<Tensor>, ModelError>

Shard a tensor along its first dimension into num_shards roughly equal parts.

This is the core primitive for Fully Sharded Data Parallel (FSDP)-style training: large parameter tensors are split across workers so that each worker stores only its own shard. Before a forward or backward pass, the shards are gathered into the full tensor (see gather_shards); after the pass, each worker keeps only the gradients for its local shard.

Each shard is a separate Tensor that can be placed on a different device. If num_shards does not evenly divide the first dimension, the earlier shards (as many as the remainder) each receive one extra row, so shard sizes differ by at most one.
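The split arithmetic can be sketched in plain Rust. This is an illustration only, not the library's implementation: `shard_rows` is a hypothetical helper that splits a slice of rows the same way the description above says shard_tensor splits the first dimension, with the remainder distributed one row at a time to the earliest shards.

```rust
// Hypothetical sketch of the sharding arithmetic, using a plain slice
// of rows in place of the library's Tensor type.
fn shard_rows<T: Clone>(rows: &[T], num_shards: usize) -> Vec<Vec<T>> {
    let base = rows.len() / num_shards; // minimum rows per shard
    let extra = rows.len() % num_shards; // first `extra` shards get one more row
    let mut shards = Vec::with_capacity(num_shards);
    let mut start = 0;
    for i in 0..num_shards {
        let len = base + if i < extra { 1 } else { 0 };
        shards.push(rows[start..start + len].to_vec());
        start += len;
    }
    shards
}

fn main() {
    // 10 rows split into 4 shards: sizes 3, 3, 2, 2.
    let rows: Vec<i32> = (0..10).collect();
    let shards = shard_rows(&rows, 4);
    let sizes: Vec<usize> = shards.iter().map(|s| s.len()).collect();
    println!("{:?}", sizes); // [3, 3, 2, 2]
    // Concatenating the shards in order recovers the original rows,
    // which is what a gather step relies on.
    assert_eq!(shards.concat(), rows);
}
```

Concatenating the shards in order reconstructs the original first dimension, which is the invariant a gather step (such as gather_shards) depends on.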