

Function embedding_lookup 

pub fn embedding_lookup(
    weight: &Tensor,
    indices: &Tensor,
) -> Result<Tensor, KernelError>

Looks up embeddings by gathering the rows of the weight matrix selected by indices.

weight: [vocab_size, embed_dim]
indices: [*] (flat tensor of integer indices, stored as f32)
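Because indices are stored as f32, each one has to be converted to a row number before it can address weight. The sketch below shows one way to do that conversion safely. The helper name index_from_f32 and its String error type are hypothetical and not part of this crate's API. Note that f32 represents every integer only up to 2^24 exactly, so very large vocabulary ids may already be rounded before they reach the kernel.

```rust
/// Hypothetical helper (not part of this crate): converts one index stored
/// as f32 into a row number, rejecting values that are not exact
/// non-negative integers below `vocab_size`.
fn index_from_f32(value: f32, vocab_size: usize) -> Result<usize, String> {
    // NaN and infinities fail `is_finite`; 2.5 fails the `fract` check.
    if !value.is_finite() || value < 0.0 || value.fract() != 0.0 {
        return Err(format!("index {value} is not a non-negative integer"));
    }
    let idx = value as usize;
    if idx >= vocab_size {
        return Err(format!("index {idx} out of range for vocab_size {vocab_size}"));
    }
    Ok(idx)
}

fn main() {
    assert_eq!(index_from_f32(3.0, 10), Ok(3));
    assert!(index_from_f32(2.5, 10).is_err());
    assert!(index_from_f32(-1.0, 10).is_err());
    assert!(index_from_f32(10.0, 10).is_err());
    assert!(index_from_f32(f32::NAN, 10).is_err());
}
```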

Returns: [*indices_shape, embed_dim]
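A minimal sketch of the gather semantics described above, assuming a row-major f32 buffer. The Tensor and KernelError definitions here are hypothetical stand-ins (the crate's real types are not shown on this page), and the sketch accepts an indices tensor of any shape so that the output shape [*indices_shape, embed_dim] is visible. For each index, one embed_dim-long row of weight is copied to the output.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: a row-major f32 buffer plus its shape.
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

/// Hypothetical stand-in for the crate's `KernelError`.
#[derive(Debug, PartialEq)]
enum KernelError {
    ShapeMismatch(String),
    IndexOutOfRange { index: f32, vocab_size: usize },
}

/// Sketch of an embedding gather: row `indices[i]` of `weight` becomes
/// position `i` of the output, whose shape is indices.shape + [embed_dim].
fn embedding_lookup(weight: &Tensor, indices: &Tensor) -> Result<Tensor, KernelError> {
    let (vocab_size, embed_dim) = match weight.shape.as_slice() {
        [v, d] => (*v, *d),
        other => {
            return Err(KernelError::ShapeMismatch(format!(
                "weight must be 2-D, got shape {other:?}"
            )))
        }
    };
    if weight.data.len() != vocab_size * embed_dim {
        return Err(KernelError::ShapeMismatch(
            "weight buffer length does not match its shape".to_string(),
        ));
    }

    let mut data = Vec::with_capacity(indices.data.len() * embed_dim);
    for &raw in &indices.data {
        // Indices arrive as f32; accept only exact integers in [0, vocab_size).
        // NaN fails `raw >= 0.0`; infinity fails the `fract` check.
        if !(raw >= 0.0 && raw.fract() == 0.0 && (raw as usize) < vocab_size) {
            return Err(KernelError::IndexOutOfRange { index: raw, vocab_size });
        }
        let row = raw as usize;
        data.extend_from_slice(&weight.data[row * embed_dim..(row + 1) * embed_dim]);
    }

    // Output shape: every dimension of `indices`, followed by embed_dim.
    let mut shape = indices.shape.clone();
    shape.push(embed_dim);
    Ok(Tensor { data, shape })
}

fn main() {
    // Vocabulary of 3 tokens, 2-dimensional embeddings.
    let weight = Tensor { data: vec![0.0, 0.1, 1.0, 1.1, 2.0, 2.1], shape: vec![3, 2] };
    // A [2, 2] batch of indices stored as f32.
    let indices = Tensor { data: vec![2.0, 0.0, 1.0, 1.0], shape: vec![2, 2] };
    let out = embedding_lookup(&weight, &indices).unwrap();
    assert_eq!(out.shape, vec![2, 2, 2]);
    assert_eq!(out.data, vec![2.0, 2.1, 0.0, 0.1, 1.0, 1.1, 1.0, 1.1]);
}
```

Copying whole rows with extend_from_slice keeps each embedding contiguous in the output, which matches the [*indices_shape, embed_dim] layout in row-major order.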