tch_plus/nn/
sparse.rs

1//! Sparse Layers
2use crate::Tensor;
3use std::borrow::Borrow;
4
5/// Configuration option for an embedding layer.
6#[derive(Debug, Clone, Copy)]
7pub struct EmbeddingConfig {
8    pub sparse: bool,
9    pub scale_grad_by_freq: bool,
10    pub ws_init: super::Init,
11    pub padding_idx: i64,
12}
13
14impl Default for EmbeddingConfig {
15    fn default() -> Self {
16        EmbeddingConfig {
17            sparse: false,
18            scale_grad_by_freq: false,
19            ws_init: super::Init::Randn { mean: 0., stdev: 1. },
20            padding_idx: -1,
21        }
22    }
23}
24
25/// An embedding layer.
26///
27/// An embedding layer acts as a simple lookup table that stores embeddings.
28/// This is commonly used to store word embeddings.
29#[derive(Debug)]
30pub struct Embedding {
31    pub ws: Tensor,
32    config: EmbeddingConfig,
33}
34
35pub fn embedding<'a, T: Borrow<super::Path<'a>>>(
36    vs: T,
37    num_embeddings: i64,
38    embedding_dim: i64,
39    config: EmbeddingConfig,
40) -> Embedding {
41    let vs = vs.borrow();
42    Embedding { ws: vs.var("weight", &[num_embeddings, embedding_dim], config.ws_init), config }
43}
44
45impl super::module::Module for Embedding {
46    fn forward(&self, xs: &Tensor) -> Tensor {
47        Tensor::embedding(
48            &self.ws,
49            xs,
50            self.config.padding_idx,
51            self.config.scale_grad_by_freq,
52            self.config.sparse,
53        )
54    }
55}