Skip to main content

zyx_nn/
embedding.rs

1// Copyright (C) 2025 zk4x
2// SPDX-License-Identifier: LGPL-3.0-only
3
4use zyx::{DType, Tensor, ZyxError};
5use zyx_derive::Module;
6
7/// Embedding layer
8#[derive(Debug, Module)]
9#[cfg_attr(feature = "py", pyo3::pyclass)]
10pub struct Embedding {
11    /// Vocabulary size
12    pub vocab_size: u64,
13    /// Embedding size
14    pub embed_size: u64,
15    /// Weight
16    pub weight: Tensor,
17    /// Arange
18    pub arange: Tensor,
19}
20
21impl Embedding {
22    /// new embedding layer
23    pub fn new(vocab_size: u64, embed_size: u64, dtype: DType) -> Result<Embedding, ZyxError> {
24        Ok(Embedding {
25            vocab_size,
26            embed_size,
27            weight: Tensor::glorot_uniform([vocab_size, embed_size], dtype)?
28                .reshape([1, 1, vocab_size, embed_size])?,
29            arange: Tensor::arange(0, vocab_size as i64, 1)?
30                .reshape([1, 1, vocab_size, 1])?
31                .cast(dtype),
32        })
33    }
34
35    /// Initialize embedding using only weight
36    pub fn from_params(weight: Tensor) -> Result<Embedding, ZyxError> {
37        let sh = weight.shape();
38        assert_eq!(sh.len(), 2);
39        Ok(Embedding {
40            vocab_size: sh[0],
41            embed_size: sh[1],
42            arange: Tensor::arange(0, sh[0] as i64, 1)?
43                .reshape([1, 1, sh[0], 1])?
44                .cast(weight.dtype()),
45            weight,
46        })
47    }
48
49    /// Forward embedding layer
50    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
51        let x: Tensor = x.into();
52        let x_sh = x.shape();
53        if x.numel() == 0 {
54            return Ok(Tensor::zeros(
55                x_sh.iter()
56                    .copied()
57                    .chain([self.embed_size])
58                    .collect::<Vec<u64>>(),
59                x.dtype(),
60            ));
61        }
62        let xdt = x.dtype();
63        let wdt = self.weight.dtype();
64        if xdt != wdt {
65            return Err(ZyxError::DTypeError(
66                format!("Embedding::forward input x has dtype {xdt} but weight has dtype {wdt}")
67                    .into(),
68            ));
69        }
70        let big_shp: Vec<u64> = x_sh
71            .iter()
72            .copied()
73            .chain([self.vocab_size, self.embed_size])
74            .collect();
75        let arange = self.arange.expand(big_shp.clone())?;
76        let idx = x
77            .reshape(x_sh.into_iter().chain([1, 1]).collect::<Vec<u64>>())?
78            .expand(big_shp.clone())?;
79        let vals = self.weight.expand(big_shp)?;
80        (arange.equal(idx)?.cast(xdt) * vals).sum([2])
81    }
82}