tiny_recursive_rs/layers/embeddings.rs

//! Embedding layer with automatic dtype casting.
2use candle_core::{Module, Result, Tensor, DType};
3use candle_nn::{Embedding, VarBuilder};
4
5pub struct CastedEmbedding {
6    embedding: Embedding,
7    target_dtype: DType,
8}
9
10impl CastedEmbedding {
11    pub fn new(vocab_size: usize, hidden_size: usize, vb: VarBuilder, target_dtype: DType) -> Result<Self> {
12        let embedding = candle_nn::embedding(vocab_size, hidden_size, vb)?;
13        Ok(Self {
14            embedding,
15            target_dtype,
16        })
17    }
18
19    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
20        let output = self.embedding.forward(input)?;
21        if output.dtype() != self.target_dtype {
22            output.to_dtype(self.target_dtype)
23        } else {
24            Ok(output)
25        }
26    }
27}