smelte_rs/nn/layers/embedding.rs

1use crate::traits::{Tensor, TensorOps};
2use crate::SmeltError;
3
4/// TODO
5#[derive(Clone)]
6pub struct Embedding<T: Tensor> {
7    weight: T,
8}
9
10impl<T: Tensor + TensorOps<T>> Embedding<T> {
11    /// TODO
12    pub fn new(weight: T) -> Self {
13        Self { weight }
14    }
15
16    /// TODO
17    pub fn forward(&self, ids: &[usize], out: &mut T) -> Result<(), SmeltError> {
18        T::select(ids, &self.weight, out)
19    }
20
21    /// TODO
22    pub fn weight(&self) -> &T {
23        &self.weight
24    }
25}
26
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu::f32::Tensor;

    #[test]
    fn test_embedding() {
        // Vocabulary of 3 ids, hidden size 2.
        let weights = Tensor::zeros(vec![3, 2]);
        let embedding = Embedding::new(weights);

        // One output row per requested id.
        let mut out = Tensor::zeros(vec![2, 2]);
        embedding.forward(&[0, 1], &mut out).unwrap();

        // Boundary: the last valid id (vocab_size - 1) must be accepted.
        let mut out = Tensor::zeros(vec![1, 2]);
        embedding.forward(&[2], &mut out).unwrap();
    }

    #[test]
    fn test_embedding_errors() {
        let weights = Tensor::zeros(vec![3, 2]);
        let embedding = Embedding::new(weights);
        // `out` is sized exactly for a single id, so the only thing wrong with
        // this call is the out-of-range id (3 >= vocab size 3). A `[2, 2]`
        // buffer would also mismatch `ids.len() == 1`, letting the assertion
        // pass on a shape error without exercising the vocabulary bounds check.
        let mut out = Tensor::zeros(vec![1, 2]);
        assert!(embedding.forward(&[3], &mut out).is_err());
    }
}