smelte_rs/nn/layers/
embedding.rs1use crate::traits::{Tensor, TensorOps};
2use crate::SmeltError;
3
4#[derive(Clone)]
6pub struct Embedding<T: Tensor> {
7 weight: T,
8}
9
10impl<T: Tensor + TensorOps<T>> Embedding<T> {
11 pub fn new(weight: T) -> Self {
13 Self { weight }
14 }
15
16 pub fn forward(&self, ids: &[usize], out: &mut T) -> Result<(), SmeltError> {
18 T::select(ids, &self.weight, out)
19 }
20
21 pub fn weight(&self) -> &T {
23 &self.weight
24 }
25}
26
#[cfg(test)]
mod tests {
    use super::*;
    // Concrete CPU f32 tensor; this explicit import shadows the `Tensor` trait
    // brought in by the glob import above.
    use crate::cpu::f32::Tensor;

    /// Valid ids, including the last row of the table, must be accepted.
    #[test]
    fn test_embedding() {
        let weights = Tensor::zeros(vec![3, 2]);
        let embedding = Embedding::new(weights);
        let mut out = Tensor::zeros(vec![2, 2]);
        embedding.forward(&[0, 1], &mut out).unwrap();
        // Boundary: id 2 is the last valid id for a table whose first dim is 3.
        embedding.forward(&[2, 0], &mut out).unwrap();
    }

    /// An id past the end of the table must produce an error.
    #[test]
    fn test_embedding_errors() {
        let weights = Tensor::zeros(vec![3, 2]);
        let embedding = Embedding::new(weights);
        // `out` is sized for exactly one id, so the id itself (3, one past the
        // last valid row) is the only thing wrong with the call. With a
        // `[2, 2]` buffer, an error could instead come from an ids/out size
        // check (if `select` performs one), and the test would not prove the
        // id is bounds-checked.
        let mut out = Tensor::zeros(vec![1, 2]);
        assert!(embedding.forward(&[3], &mut out).is_err());
    }
}