1use zyx::{DType, Tensor, ZyxError};
5use zyx_derive::Module;
6
/// Embedding layer: a lookup table that maps index values to dense vectors of
/// length `embed_size`.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Embedding {
    /// Number of rows in the lookup table (how many distinct indices can be looked up).
    pub vocab_size: u64,
    /// Length of each embedding vector (columns of the table).
    pub embed_size: u64,
    /// The lookup table itself. `Embedding::new` stores it reshaped to
    /// `[1, 1, vocab_size, embed_size]`; `Embedding::from_params` stores the provided
    /// `[vocab_size, embed_size]` tensor unchanged.
    pub weight: Tensor,
    /// Row indices `0..vocab_size`, shaped `[1, 1, vocab_size, 1]` and cast to the weight's
    /// dtype; compared element-wise against the input in `forward` to build a one-hot mask.
    // NOTE(review): this is a public `Tensor` field, so `derive(Module)` may expose it as a
    // trainable parameter alongside `weight` — confirm against zyx_derive's behaviour.
    pub arange: Tensor,
}
20
21impl Embedding {
22 pub fn new(vocab_size: u64, embed_size: u64, dtype: DType) -> Result<Embedding, ZyxError> {
24 Ok(Embedding {
25 vocab_size,
26 embed_size,
27 weight: Tensor::glorot_uniform([vocab_size, embed_size], dtype)?
28 .reshape([1, 1, vocab_size, embed_size])?,
29 arange: Tensor::arange(0, vocab_size as i64, 1)?
30 .reshape([1, 1, vocab_size, 1])?
31 .cast(dtype),
32 })
33 }
34
35 pub fn from_params(weight: Tensor) -> Result<Embedding, ZyxError> {
37 let sh = weight.shape();
38 assert_eq!(sh.len(), 2);
39 Ok(Embedding {
40 vocab_size: sh[0],
41 embed_size: sh[1],
42 arange: Tensor::arange(0, sh[0] as i64, 1)?
43 .reshape([1, 1, sh[0], 1])?
44 .cast(weight.dtype()),
45 weight,
46 })
47 }
48
49 pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
51 let x: Tensor = x.into();
52 let x_sh = x.shape();
53 if x.numel() == 0 {
54 return Ok(Tensor::zeros(
55 x_sh.iter()
56 .copied()
57 .chain([self.embed_size])
58 .collect::<Vec<u64>>(),
59 x.dtype(),
60 ));
61 }
62 let xdt = x.dtype();
63 let wdt = self.weight.dtype();
64 if xdt != wdt {
65 return Err(ZyxError::DTypeError(
66 format!("Embedding::forward input x has dtype {xdt} but weight has dtype {wdt}")
67 .into(),
68 ));
69 }
70 let big_shp: Vec<u64> = x_sh
71 .iter()
72 .copied()
73 .chain([self.vocab_size, self.embed_size])
74 .collect();
75 let arange = self.arange.expand(big_shp.clone())?;
76 let idx = x
77 .reshape(x_sh.into_iter().chain([1, 1]).collect::<Vec<u64>>())?
78 .expand(big_shp.clone())?;
79 let vals = self.weight.expand(big_shp)?;
80 (arange.equal(idx)?.cast(xdt) * vals).sum([2])
81 }
82}