1use crate::Tensor;
3use std::borrow::Borrow;
4
5#[derive(Debug, Clone, Copy)]
7pub struct EmbeddingConfig {
8 pub sparse: bool,
9 pub scale_grad_by_freq: bool,
10 pub ws_init: super::Init,
11 pub padding_idx: i64,
12}
13
14impl Default for EmbeddingConfig {
15 fn default() -> Self {
16 EmbeddingConfig {
17 sparse: false,
18 scale_grad_by_freq: false,
19 ws_init: super::Init::Randn { mean: 0., stdev: 1. },
20 padding_idx: -1,
21 }
22 }
23}
24
25#[derive(Debug)]
30pub struct Embedding {
31 pub ws: Tensor,
32 config: EmbeddingConfig,
33}
34
35pub fn embedding<'a, T: Borrow<super::Path<'a>>>(
36 vs: T,
37 num_embeddings: i64,
38 embedding_dim: i64,
39 config: EmbeddingConfig,
40) -> Embedding {
41 let vs = vs.borrow();
42 Embedding { ws: vs.var("weight", &[num_embeddings, embedding_dim], config.ws_init), config }
43}
44
45impl super::module::Module for Embedding {
46 fn forward(&self, xs: &Tensor) -> Tensor {
47 Tensor::embedding(
48 &self.ws,
49 xs,
50 self.config.padding_idx,
51 self.config.scale_grad_by_freq,
52 self.config.sparse,
53 )
54 }
55}