tiny_recursive_rs/layers/
embeddings.rs1use candle_core::{Module, Result, Tensor, DType};
3use candle_nn::{Embedding, VarBuilder};
4
5pub struct CastedEmbedding {
6 embedding: Embedding,
7 target_dtype: DType,
8}
9
10impl CastedEmbedding {
11 pub fn new(vocab_size: usize, hidden_size: usize, vb: VarBuilder, target_dtype: DType) -> Result<Self> {
12 let embedding = candle_nn::embedding(vocab_size, hidden_size, vb)?;
13 Ok(Self {
14 embedding,
15 target_dtype,
16 })
17 }
18
19 pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
20 let output = self.embedding.forward(input)?;
21 if output.dtype() != self.target_dtype {
22 output.to_dtype(self.target_dtype)
23 } else {
24 Ok(output)
25 }
26 }
27}